diff --git a/paddle/fluid/framework/new_executor/interpreter/static_build.cc b/paddle/fluid/framework/new_executor/interpreter/static_build.cc
index d91c59f60627b25444fb91643a7895537215d20a..c54bdef618fbff4bb4168d310a8d3d1f130336ae 100644
--- a/paddle/fluid/framework/new_executor/interpreter/static_build.cc
+++ b/paddle/fluid/framework/new_executor/interpreter/static_build.cc
@@ -40,20 +40,6 @@ std::set<std::string> OpsCanSkipedFakeAllocInStaticBuild = {
     "fetch_v2",
     "nop"};
 
-// Cannot static analysis these Ops' output dtype or backend because their
-// kernels have not moved to PHI yet.
-std::set<std::string> OpsWithFluidKernelNeedMoveToPhi = {
-    "cudnn_lstm",
-    "dequantize",
-    "distributed_fused_lamb",
-    "fused_batch_norm_act",
-    "fused_batch_norm_act_grad",
-    "fusion_group",
-    "pow2_decay_with_linear_warmup",
-    "sequence_mask",
-    "sequence_pool",
-    "stft"};
-
 std::set<std::string> StaticBuildBlackList = {
     "batch_norm" /*: to handle reserve_space output*/,
     "cinn_instruction_run" /*: to handle subgraph infermeta*/,
@@ -95,8 +81,7 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {
     bool has_fluid_kernel = OperatorWithKernel::AllOpKernels().count(op_type);
     bool has_structured_kernel =
         phi::KernelFactory::Instance().HasStructuredKernel(op_type);
-    bool need_move_to_phi = (has_fluid_kernel || has_structured_kernel) &&
-                            OpsWithFluidKernelNeedMoveToPhi.count(op_type);
+    bool need_move_to_phi = (has_fluid_kernel || has_structured_kernel);
 
     KernelCode kernel_code = (in_black_list << 7) + (is_operator_base << 6) +
                              (is_custom_op << 5) +
@@ -162,6 +147,11 @@ bool TensorShouldBeFakeInitialized(const OperatorBase& op,
     return false;
   }
 
+  if (op_type == "distributed_fused_lamb" && parameter_name == "ParamOut") {
+    VLOG(2) << "Skip fake initialization for: " << parameter_name;
+    return false;
+  }
+
   if (op_type == "fake_quantize_range_abs_max") {
     if (op.Attr<bool>("is_test") &&
         (parameter_name == "OutScale" || parameter_name == "OutScales")) {
@@ -467,11 +457,14 @@ void FakeInitializeOutputsForFunctionKernel(
       } else if (op_type == "layer_norm") {
         dtype = InferMPDType(runtime_ctx, "X");
       } else if (op_type == "reduce_sum") {
+        phi::DataType in_dtype = GetInputDType(runtime_ctx, "X");
         int dtype_attr = op.Attr<int>("out_dtype");
         if (dtype_attr != -1) {
          dtype = phi::TransToPhiDataType(dtype_attr);
+          if (dtype == DataType::UNDEFINED) {
+            dtype = in_dtype;
+          }
         } else {
-          phi::DataType in_dtype = GetInputDType(runtime_ctx, "X");
           dtype = (in_dtype == DataType::BOOL || in_dtype == DataType::INT32)
                      ? DataType::INT64
                      : in_dtype;
@@ -489,7 +482,6 @@ void FakeInitializeOutputsForFunctionKernel(
 
       // analyze layout
       phi::DataLayout layout = tensor_arg_def.layout;
-
      FakeInitializeTensorBase(dev_ctx, place, dtype, layout, out_tensor);
     }
   }
diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc
index 28a8d9564ec0b70bd25008bb9ea130ca35d99033..d8224cb0dd72bd4f1fbb175be19d2db6055d822e 100644
--- a/paddle/fluid/framework/tensor_util.cc
+++ b/paddle/fluid/framework/tensor_util.cc
@@ -44,7 +44,6 @@ void TensorCopyImpl(const TENSOR& src,
     TensorCopyImpl(src_copy, dst_place, ctx, dst);
     return;
   }
-
   VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
           << dst_place;
   src.check_memory_size();
@@ -325,7 +324,6 @@ void TensorCopySync(const phi::DenseTensor& src,
           << dst_place;
     return;
   }
-
   auto size = src.numel() * phi::SizeOf(src.dtype());
   if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
     memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc
index a3cd48cc50c0dbf0b9a2acc5a72a9f2f97f28810..e61512924f81d1c533b3362653614c2e85b4ee0f 100644
--- a/paddle/fluid/operators/cudnn_lstm_op.cc
+++ b/paddle/fluid/operators/cudnn_lstm_op.cc
@@ -15,8 +15,12 @@ limitations under the License. */
 #include <memory>
 #include <string>
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+
+#include "paddle/phi/infermeta/multiary.h"
 
 namespace paddle {
 namespace operators {
@@ -25,75 +29,6 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM");
-    OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM");
-    OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM");
-
-    OP_INOUT_CHECK(ctx->HasOutput("Reserve"), "Output", "Reserve", "CudnnLSTM");
-    OP_INOUT_CHECK(
-        ctx->HasOutput("StateOut"), "Output", "StateOut", "CudnnLSTM");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CudnnLSTM");
-    OP_INOUT_CHECK(ctx->HasOutput("LastH"), "Output", "LastH", "CudnnLSTM");
-    OP_INOUT_CHECK(ctx->HasOutput("LastC"), "Output", "LastC", "CudnnLSTM");
-
-    auto in_dims = ctx->GetInputDim("Input");
-    auto init_h_dims = ctx->GetInputDim("InitH");
-    auto init_c_dims = ctx->GetInputDim("InitC");
-
-    PADDLE_ENFORCE_EQ(in_dims.size(),
-                      3,
-                      platform::errors::InvalidArgument(
-                          "The rank of Input in CudnnLSTM must be 3. But "
-                          "received Input's rank is %d.",
-                          in_dims.size()));
-    PADDLE_ENFORCE_EQ(init_h_dims.size(),
-                      3,
-                      platform::errors::InvalidArgument(
-                          "The rank of InitH in CudnnLSTM must be 3. But "
-                          "received InitH's rank is %d.",
-                          init_h_dims.size()));
-
-    if (ctx->HasInput("SequenceLength")) {
-      auto seq_dims = ctx->GetInputDim("SequenceLength");
-      PADDLE_ENFORCE_EQ(
-          in_dims[1],
-          seq_dims[0],
-          platform::errors::InvalidArgument(
-              "The size of SequenceLength has to equal the batch_size. But "
-              "received batch_size is %d and the size of SequenceLength is %d.",
-              in_dims[1],
-              seq_dims[0]));
-    }
-
-    PADDLE_ENFORCE_EQ(
-        in_dims[1],
-        init_h_dims[1],
-        platform::errors::InvalidArgument(
-            "The in_dims[1] (Input dims) and init_h_dims[1] (InitH "
-            "dims) should be equal. But "
But " - "received in_dims[1] is %d and init_h_dims[1] is %d.", - in_dims[1], - init_h_dims[1])); - - PADDLE_ENFORCE_EQ(init_c_dims, - init_h_dims, - platform::errors::InvalidArgument( - "The InitC dims and InitH " - "dims should be equal. But " - "received init_c_dims is %d and init_h_dims is %d.", - init_c_dims, - init_h_dims)); - - auto out_dims = in_dims; - auto hidden_size = ctx->Attrs().Get("hidden_size"); - bool is_bidirec = ctx->Attrs().Get("is_bidirec"); - out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size; - ctx->SetOutputDim("Out", out_dims); - ctx->SetOutputDim("LastH", init_c_dims); - ctx->SetOutputDim("LastC", init_h_dims); - } - protected: phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -295,12 +230,18 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker { } // namespace operators } // namespace paddle +DECLARE_INFER_SHAPE_FUNCTOR(cudnn_lstm, + CudnnLSTMInferShapeFunctor, + PD_INFER_META(phi::CudnnLSTMInferMeta)); + namespace ops = paddle::operators; REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker, ops::CudnnLSTMGradOpMaker, - ops::CudnnLSTMGradOpMaker); + ops::CudnnLSTMGradOpMaker, + CudnnLSTMInferShapeFunctor); + REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp); // TODO(Shixiaowei02) Add ModifyInput support diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu index 6ca9ae3b76d9d486a256a7de99e3d8ae6e43411f..cad7e38ba1c1a311f96c56a901972f91cd7fbf1c 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu @@ -270,7 +270,7 @@ static const T *GetInputTensorPtr(const DenseTensor *in_tensor, PADDLE_ENFORCE_NOT_NULL( in_tensor, phi::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name)); - if (in_tensor->IsInitialized()) { + if (in_tensor->initialized()) { if (numel) *numel = in_tensor->numel(); return in_tensor->data(); } else { @@ -286,7 +286,7 @@ static T *GetSameInOutTensorPtr(const Context &dev_ctx, const char *in_name, const char *out_name, int64_t *numel = nullptr) { - if (in_tensor == nullptr || !in_tensor->IsInitialized()) { + if (in_tensor == nullptr || !in_tensor->initialized()) { PADDLE_ENFORCE_EQ( AllowNotExist, true, @@ -1160,7 +1160,7 @@ struct VisitDTypeFunctor { static std::string GetMinMaxStr(const phi::DenseTensor *x) { if (x == nullptr) return "null"; - if (!x->IsInitialized()) return "not_inited"; + if (!x->initialized()) return "not_inited"; if (x->place().GetType() != phi::AllocationType::GPU) return "CPUTensor"; std::string str; VisitDTypeFunctor functor(x, &str); @@ -1354,9 +1354,7 @@ void DistributedFusedLambKernel( #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) auto stream = dev_ctx.stream(); auto place = dev_ctx.GetPlace(); - found_inf->Resize({1}); - // Step 1: Get fp16 param and grad tensors int64_t fp16_numel; auto *fp16_param_data = @@ -1374,7 +1372,6 @@ void DistributedFusedLambKernel( } else { fp16_param_data = nullptr; } - // Step 2: Get fp32 param and grad tensors int64_t fp32_numel = 0; auto *fp32_param_data = @@ -1389,7 +1386,6 @@ void DistributedFusedLambKernel( phi::errors::InvalidArgument( "The element number in FP32FusedParam should be not " "less than FP16FusedParam.")); - fp32_numel -= fp16_numel; // the FP32FusedParam contains fp32 param and // fp16 master weight bool has_fp32_param = (fp32_numel > 0); @@ -1404,7 +1400,6 @@ void 
         phi::errors::InvalidArgument(
             "Either FP32FusedGrad or FP16FusedGrad cannot be NULL."));
   }
-
   auto numel = fp32_numel + fp16_numel;
   VLOG(1) << "numel = " << numel << " , fp32_numel = " << fp32_numel
           << " , fp16_numel = " << fp16_numel;
@@ -1426,7 +1421,7 @@ void DistributedFusedLambKernel(
         acc_step,
         phi::errors::InvalidArgument(
             "Output(AccStep) cannot be nullptr when Attr(acc_steps) > 1."));
-    bool is_initialized = acc_step->IsInitialized();
+    bool is_initialized = acc_step->initialized();
     int64_t *acc_step_data;
     if (is_initialized) {
       acc_step_data = dev_ctx.template HostAlloc<int64_t>(acc_step);
@@ -1437,14 +1432,13 @@ void DistributedFusedLambKernel(
       *acc_step_data = 1;
     }
     int64_t rounded_step = (*acc_step_data) % acc_steps;
-
     float *fp32_acc_grad_data = nullptr;
     if (has_fp32_param) {
       PADDLE_ENFORCE_NOT_NULL(fp32_acc_grad,
                               phi::errors::InvalidArgument(
                                   "Output(FP32AccFusedGrad) cannot be nullptr "
                                   "when Attr(acc_steps) > 1."));
-      if (!fp32_acc_grad->IsInitialized()) {
+      if (!fp32_acc_grad->initialized()) {
         fp32_acc_grad->Resize({static_cast<int64_t>(fp32_numel)});
         fp32_acc_grad_data = dev_ctx.template Alloc<float>(fp32_acc_grad);
       } else {
@@ -1459,7 +1453,7 @@ void DistributedFusedLambKernel(
                               phi::errors::InvalidArgument(
                                   "Output(FP16AccFusedGrad) cannot be nullptr "
                                   "when Attr(acc_steps) > 1."));
-      if (!fp16_acc_grad->IsInitialized()) {
+      if (!fp16_acc_grad->initialized()) {
         auto acc_grad_size =
             use_master_acc_grad ? (3 * fp16_numel) : fp16_numel;
         fp16_acc_grad->Resize({static_cast<int64_t>(acc_grad_size)});
@@ -1544,11 +1538,9 @@ void DistributedFusedLambKernel(
         }
       }
     }
-
     stop_update->Resize({1});
     auto *stop_update_data = dev_ctx.template HostAlloc<bool>(stop_update);
     auto *found_inf_cpu = dev_ctx.template HostAlloc<bool>(found_inf);
-
     if (rounded_step != 0) {
       *stop_update_data = true;
       *found_inf_cpu = false;
@@ -1599,7 +1591,6 @@ void DistributedFusedLambKernel(
   int64_t partial_numel = 0;
   auto *moment1_data = GetSameInOutTensorPtr<float>(
       dev_ctx, &moment1, moment1_out, "Moment1", "Moment1Out", &partial_numel);
-
   PADDLE_ENFORCE_EQ(numel % partial_numel,
                     0,
                     phi::errors::InvalidArgument(
@@ -1629,16 +1620,13 @@ void DistributedFusedLambKernel(
                         "exactly by the device number %d.",
                         fp16_numel,
                         num_devices));
-
   auto *moment2_data = GetSameInOutTensorPtr<float>(
       dev_ctx, &moment2, moment2_out, "Moment2", "Moment2Out");
   auto *beta1_pow_data = GetSameInOutTensorPtr<float>(
       dev_ctx, &beta1_pow, beta1_pow_out, "Beta1Pow", "Beta1PowOut");
   auto *beta2_pow_data = GetSameInOutTensorPtr<float>(
       dev_ctx, &beta2_pow, beta2_pow_out, "Beta2Pow", "Beta2PowOut");
-
   auto *found_inf_data = dev_ctx.template Alloc<bool>(found_inf);
-
   // Step 5: Get attributes weight_decay, beta1, beta2, epsilon,
   // max_grad_norm, ring_id,
   // use_master_param_norm, is_grad_scaled_by_nranks
@@ -1668,7 +1656,6 @@ void DistributedFusedLambKernel(
         paddle::platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
     global_comm = nccl_comm_handle->comm();
     global_rank = nccl_comm_handle->rank();
-
     if (local_shard) {
       auto *local_nccl_comm_handle =
           paddle::platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
@@ -1684,11 +1671,9 @@ void DistributedFusedLambKernel(
       local_rank = global_rank;
     }
   }
-
   memory_utils::Buffer grad_norm_square_buffer(place);
   auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc<float>(2);
   memory_utils::Buffer cub_tmp_buffer(place);
-
   memory_utils::Buffer sum_grad_buffer(place);
   float *fp32_sum_grad;
   dtype::float16 *fp16_sum_grad;
@@ -1721,7 +1706,6 @@ void DistributedFusedLambKernel(
     fp32_sum_grad = const_cast<float *>(fp32_grad_data);
     fp16_sum_grad = const_cast<dtype::float16 *>(fp16_grad_data);
   }
-
   float rescale_grad = 1.0f;
   if (!is_grad_scaled_by_nranks) {
     rescale_grad /= nranks;
   }
@@ -1845,7 +1829,6 @@ void DistributedFusedLambKernel(
   } else {
     fp16_scale = cub_tmp_buffer.Alloc<dtype::float16>(1);
   }
-
   float clip_scale = 1.0f;
   if (is_grad_scaled_by_nranks) {
     clip_scale *= nranks;
   }
@@ -1988,7 +1971,6 @@ void DistributedFusedLambKernel(
           external_comm,
           stream,
           dev_ctx);
-
       NCCLReduceScatterWithScale(
           fp16_grad_data,
           fp16_sum_grad + local_rank * fp16_numel_each_device,
@@ -2064,9 +2046,7 @@ void DistributedFusedLambKernel(
   auto *param_offsets_data = param_offsets.data<int>();
   const auto *fp32_partial_offsets_data = fp32_partial_offsets.data<int>();
   const auto *fp16_partial_offsets_data = fp16_partial_offsets.data<int>();
-
   auto *step_data = step->data<int64_t>();
-
   VLOG(1) << "FusedParamOffsets: "
           << FlattenToString(param_offsets_data,
                              param_offsets.numel(),
@@ -2079,7 +2059,6 @@ void DistributedFusedLambKernel(
           << FlattenToString(fp16_partial_offsets_data,
                              fp16_partial_offsets.numel(),
                              fp16_partial_offsets.place());
-
   memory_utils::Buffer trust_ratio_div_buffer(place);
   auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel);
   auto fp32_offset = local_rank * fp32_numel_each_device;
@@ -2178,7 +2157,6 @@ void DistributedFusedLambKernel(
                       fp16_local_param_num,
                       param_square_norm + fp16_local_start_idx);
   }
-
   MultiTensorL2Norm(place,
                     stream,
                     trust_ratio_div,
@@ -2191,7 +2169,6 @@ void DistributedFusedLambKernel(
                     fp16_partial_offsets_data,
                     fp16_local_param_num,
                     trust_ratio_div_square_norm + fp16_local_start_idx);
-
   VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: "
           << FlattenToString(trust_ratio_div_square_norm, param_num, place);
   if (num_devices > 1) {
@@ -2296,6 +2273,12 @@ PD_REGISTER_KERNEL(distributed_fused_lamb,
                    ALL_LAYOUT,
                    phi::fusion::DistributedFusedLambKernel,
                    float) {
+  kernel->InputAt(10).SetBackend(phi::Backend::CPU);
+  kernel->InputAt(11).SetBackend(phi::Backend::CPU);
+  kernel->InputAt(12).SetBackend(phi::Backend::CPU);
+  kernel->InputAt(13).SetBackend(phi::Backend::CPU);
+  kernel->InputAt(14).SetBackend(phi::Backend::CPU);
+
   kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
   kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT16);
   kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index c3441688e4a7cedc598825c0262cc6e05de8a612..79ed182a1e15d1a0884e57ce33794201ffdbd6aa 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -978,6 +978,85 @@ void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
   out->share_lod(*x.at(0));
 }
 
+void CudnnLSTMInferMeta(
+    const MetaTensor& x,
+    const MetaTensor& init_h,
+    const MetaTensor& init_c,
+    const MetaTensor& w,
+    const paddle::optional<std::vector<const MetaTensor*>>& weight_list,
+    const MetaTensor& sequence_length,
+    float dropout_prob,
+    bool is_bidirec,
+    int hidden_size,
+    int num_layers,
+    bool is_test,
+    int seed,
+    MetaTensor* out,
+    MetaTensor* last_h,
+    MetaTensor* last_c,
+    MetaTensor* reserve,
+    MetaTensor* state_out) {
+  auto in_dims = x.dims();
+  auto init_h_dims = init_h.dims();
+
+  auto init_c_dims = init_c.dims();
+
+  PADDLE_ENFORCE_EQ(in_dims.size(),
+                    3,
+                    phi::errors::InvalidArgument(
+                        "The rank of Input in CudnnLSTM must be 3. But "
+                        "received Input's rank is %d.",
+                        in_dims.size()));
+  PADDLE_ENFORCE_EQ(init_h_dims.size(),
+                    3,
+                    phi::errors::InvalidArgument(
+                        "The rank of InitH in CudnnLSTM must be 3. But "
But " + "received InitH's rank is %d.", + init_h_dims.size())); + + if (sequence_length) { + auto seq_dims = sequence_length.dims(); + PADDLE_ENFORCE_EQ( + in_dims[1], + seq_dims[0], + phi::errors::InvalidArgument( + "The size of SequenceLength has to equal the batch_size. But " + "received batch_size is %d and the size of SequenceLength is %d.", + in_dims[1], + seq_dims[0])); + } + + PADDLE_ENFORCE_EQ(in_dims[1], + init_h_dims[1], + phi::errors::InvalidArgument( + "The in_dims[1] (Input dims) and init_h_dims[1] (InitH " + "dims) should be equal. But " + "received in_dims[1] is %d and init_h_dims[1] is %d.", + in_dims[1], + init_h_dims[1])); + + PADDLE_ENFORCE_EQ(init_c_dims, + init_h_dims, + phi::errors::InvalidArgument( + "The InitC dims and InitH " + "dims should be equal. But " + "received init_c_dims is %d and init_h_dims is %d.", + init_c_dims, + init_h_dims)); + + auto out_dims = in_dims; + out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size; + out->set_dims(out_dims); + out->set_dtype(x.dtype()); + last_h->set_dims(init_c_dims); + last_h->set_dtype(x.dtype()); + last_c->set_dims(init_h_dims); + last_c->set_dtype(x.dtype()); + + reserve->set_dtype(phi::DataType::UINT8); + state_out->set_dtype(phi::DataType::UINT8); +} + inline int ConvOutputSize( int input_size, int filter_size, int dilation, int padding, int stride) { const int dkernel = dilation * (filter_size - 1) + 1; diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index e4ca81cb873c02f430f9aba07fda54f4f12890fd..d9aef9f26168591e0f78b4a43dd1ed936e478fe5 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -239,6 +239,25 @@ void ConcatInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void CudnnLSTMInferMeta( + const MetaTensor& x, + const MetaTensor& init_h, + const MetaTensor& init_c, + const MetaTensor& w, + const paddle::optional>& weight_list, + const MetaTensor& sequence_length, + float dropout_prob, + bool is_bidirec, + int hidden_size, + int num_layers, + bool is_test, + int seed, + MetaTensor* out, + MetaTensor* last_h, + MetaTensor* last_c, + MetaTensor* reserve, + MetaTensor* state_out); + void DeformableConvInferMeta(const MetaTensor& x, const MetaTensor& offset, const MetaTensor& filter, diff --git a/paddle/phi/kernels/cpu/eigh_kernel.cc b/paddle/phi/kernels/cpu/eigh_kernel.cc index 2b722177d2ef38283d9fbe2588a90ffbf809a2cd..53f1b368efbcbe95ca985fd85931d224b322e819 100644 --- a/paddle/phi/kernels/cpu/eigh_kernel.cc +++ b/paddle/phi/kernels/cpu/eigh_kernel.cc @@ -43,5 +43,4 @@ PD_REGISTER_KERNEL(eigh, phi::dtype::complex, phi::dtype::complex) { kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); - kernel->OutputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); } diff --git a/paddle/phi/kernels/cpu/pow2_decay_with_linear_warmup_kernel.cc b/paddle/phi/kernels/cpu/pow2_decay_with_linear_warmup_kernel.cc index 9bdf6bb2c86ab6500d395ed6f2d9e23e3d0a97d3..816376f04f488d4907b926de719391884c419eb6 100644 --- a/paddle/phi/kernels/cpu/pow2_decay_with_linear_warmup_kernel.cc +++ b/paddle/phi/kernels/cpu/pow2_decay_with_linear_warmup_kernel.cc @@ -22,4 +22,7 @@ PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup, ALL_LAYOUT, phi::Pow2DecayWithLinearWarmupKernel, float, - double) {} + double) { + kernel->InputAt(1).SetDataType(phi::DataType::INT64); + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/cpu/sequence_mask_kernel.cc 
index 20f999f4be228471cfb1ef55a45116ad7d1e0fbf..3496c0315886697f50da308e0e9bd0565e4ccaf0 100644
--- a/paddle/phi/kernels/cpu/sequence_mask_kernel.cc
+++ b/paddle/phi/kernels/cpu/sequence_mask_kernel.cc
@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(sequence_mask,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
+}
diff --git a/paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu b/paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu
index eca9bcf18cf0ec583a702b0c6cc7865f3d50c8cd..eee5a4b84b54a60e3a0aacb02d3f8a97a9dc78c3 100644
--- a/paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu
@@ -101,4 +101,6 @@ PD_REGISTER_KERNEL(fusion_group,
                    phi::fusion::FusionGroupKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
+}
diff --git a/paddle/phi/kernels/gpu/abs_kernel.cu b/paddle/phi/kernels/gpu/abs_kernel.cu
index 8f55f49daf3de3c3d0804e11dfb3e0abbf31cdf7..85f05041adbfb990f14fd8b3e0b52d9d17d21cf3 100644
--- a/paddle/phi/kernels/gpu/abs_kernel.cu
+++ b/paddle/phi/kernels/gpu/abs_kernel.cu
@@ -77,5 +77,5 @@ PD_REGISTER_KERNEL(abs,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {
-  kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
+  kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
 }
diff --git a/paddle/phi/kernels/gpu/adadelta_kernel.cu b/paddle/phi/kernels/gpu/adadelta_kernel.cu
index 9270609d034666d9eb7ab3aeffcf531ec064d1d1..b627b4449ef7cdf756173ed1888a884bfc78b893 100644
--- a/paddle/phi/kernels/gpu/adadelta_kernel.cu
+++ b/paddle/phi/kernels/gpu/adadelta_kernel.cu
@@ -24,4 +24,10 @@ PD_REGISTER_KERNEL(adadelta,
                    phi::AdadeltaKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16) {
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+  }
+}
diff --git a/paddle/phi/kernels/gpu/adamax_kernel.cu b/paddle/phi/kernels/gpu/adamax_kernel.cu
index eca39cb1a97a061d128a2e4403a9db5ba09b5de7..2cfeddc6ceeba3078abc00309d5dc7e2022f192f 100644
--- a/paddle/phi/kernels/gpu/adamax_kernel.cu
+++ b/paddle/phi/kernels/gpu/adamax_kernel.cu
@@ -132,4 +132,10 @@ PD_REGISTER_KERNEL(adamax,
                    phi::AdamaxKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16) {
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+  }
+}
diff --git a/paddle/phi/kernels/gpu/cudnn_lstm_kernel.cu b/paddle/phi/kernels/gpu/cudnn_lstm_kernel.cu
index e3f2b780f3ffc3e5477f4bdab939836d90602c02..f3a03727e0bc45d22b762f667bd860d8f8e12972 100644
--- a/paddle/phi/kernels/gpu/cudnn_lstm_kernel.cu
+++ b/paddle/phi/kernels/gpu/cudnn_lstm_kernel.cu
@@ -176,7 +176,7 @@ void CudnnLSTMKernel(
   int seq_length = x.dims()[0];
   int batch_size = x.dims()[1];
   int input_size = x.dims()[2];
-  bool state_initialized = state_out->IsInitialized() ? true : false;
+  bool state_initialized = state_out->initialized() ? true : false;
 
   size_t workspace_size;
   size_t reserve_size;
@@ -188,7 +188,7 @@ void CudnnLSTMKernel(
   auto stream = ctx.stream();
   auto *running_w = w.get_ptr();
   if (is_test && running_w != nullptr) {
-    w_initialized = running_w->IsInitialized() ? true : false;
+    w_initialized = running_w->initialized() ? true : false;
     weight_numel = running_w->numel();
   }
   if (!w_initialized) {
@@ -362,12 +362,14 @@ void CudnnLSTMKernel(
 
 #ifdef PADDLE_WITH_HIP
 PD_REGISTER_KERNEL(cudnn_lstm, GPU, ALL_LAYOUT, phi::CudnnLSTMKernel, float) {
+  kernel->InputAt(5).SetDataType(phi::DataType::INT32);
   kernel->OutputAt(3).SetDataType(phi::DataType::UINT8);
   kernel->OutputAt(4).SetDataType(phi::DataType::UINT8);
 }
 #else
 PD_REGISTER_KERNEL(
     cudnn_lstm, GPU, ALL_LAYOUT, phi::CudnnLSTMKernel, float, double) {
+  kernel->InputAt(5).SetDataType(phi::DataType::INT32);
   kernel->OutputAt(3).SetDataType(phi::DataType::UINT8);
   kernel->OutputAt(4).SetDataType(phi::DataType::UINT8);
 }
diff --git a/paddle/phi/kernels/gpu/eigh_kernel.cu b/paddle/phi/kernels/gpu/eigh_kernel.cu
index 0544b6bdf238ad5bc074a3c67c06b52c48620ad4..f3b33ad5c9878529992c5d3b03c6909c2ec0026a 100644
--- a/paddle/phi/kernels/gpu/eigh_kernel.cu
+++ b/paddle/phi/kernels/gpu/eigh_kernel.cu
@@ -46,7 +46,6 @@ PD_REGISTER_KERNEL(eigh,  // cuda_only
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {
   kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
-  kernel->OutputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
 }
 
 #endif  // not PADDLE_WITH_HIP
diff --git a/paddle/phi/kernels/gpu/pow2_decay_with_linear_warmup_kernel.cu b/paddle/phi/kernels/gpu/pow2_decay_with_linear_warmup_kernel.cu
index 57162bc7fb2b3e02b4bf61260e8d6e423574fdb9..5ac45f95e5be310480d20a52c972d693ed153f43 100644
--- a/paddle/phi/kernels/gpu/pow2_decay_with_linear_warmup_kernel.cu
+++ b/paddle/phi/kernels/gpu/pow2_decay_with_linear_warmup_kernel.cu
@@ -22,4 +22,7 @@ PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
                    ALL_LAYOUT,
                    phi::Pow2DecayWithLinearWarmupKernel,
                    float,
-                   double) {}
+                   double) {
+  kernel->InputAt(1).SetDataType(phi::DataType::INT64);
+  kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
+}
diff --git a/paddle/phi/kernels/gpu/sequence_mask_kernel.cu b/paddle/phi/kernels/gpu/sequence_mask_kernel.cu
index 3e37e3e2419c2c47e00ea67eaa957714ecd80d7c..619d0a42356f3df8c8e46a3330aa5bab062148e9 100644
--- a/paddle/phi/kernels/gpu/sequence_mask_kernel.cu
+++ b/paddle/phi/kernels/gpu/sequence_mask_kernel.cu
@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(sequence_mask,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
+}
diff --git a/paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h b/paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h
index bbca911b404c42bb9c39623d46634cce6b78a8e8..da28f52f6173b827d799796d384ba4958d3e1855 100644
--- a/paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h
+++ b/paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h
@@ -85,7 +85,7 @@ void Pow2DecayWithLinearWarmupKernel(const Context& dev_ctx,
                     phi::errors::InvalidArgument(
                         "Input(Step) and Output(StepOut) must be the same."));
   PADDLE_ENFORCE_EQ(
-      step.IsInitialized(),
+      step.initialized(),
       true,
       phi::errors::InvalidArgument("Input(Step) must be initialized."));
 
diff --git a/paddle/phi/kernels/xpu/amp_kernel.cc b/paddle/phi/kernels/xpu/amp_kernel.cc
index 0c00baf9170d53c0fe7cb638249c3317e279dce0..c2f9984ee4145393f7a1a776506ca41eccbfe3d4 100644
--- a/paddle/phi/kernels/xpu/amp_kernel.cc
+++ b/paddle/phi/kernels/xpu/amp_kernel.cc
@@ -280,7 +280,13 @@ PD_REGISTER_KERNEL(update_loss_scaling,
                    ALL_LAYOUT,
                    phi::UpdateLossScalingKernel,
                    float,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16) {
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+  }
+  kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
+  kernel->OutputAt(3).SetDataType(phi::DataType::INT32);
+}
 
 PD_REGISTER_KERNEL(check_finite_and_unscale,
                    XPU,
diff --git a/paddle/phi/kernels/xpu/pool_kernel.cc b/paddle/phi/kernels/xpu/pool_kernel.cc
index 928bdbbdef4526f50f0be05b140e3d322a09add5..466adade072c7ad5025ddab874fa2e5b7b08f640 100644
--- a/paddle/phi/kernels/xpu/pool_kernel.cc
+++ b/paddle/phi/kernels/xpu/pool_kernel.cc
@@ -369,4 +369,6 @@ PD_REGISTER_KERNEL(max_pool2d_with_index,
                    ALL_LAYOUT,
                    phi::MaxPool2dWithIndexKernel,
                    float,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16) {
+  kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
+}
diff --git a/paddle/phi/kernels/xpu/pow2_decay_with_linear_warmup_kernel.cc b/paddle/phi/kernels/xpu/pow2_decay_with_linear_warmup_kernel.cc
index 8661f2ac2cb2202dea16a99d9c27f44d0f58287d..bfda5688bb3407a3d4bdd009a7c79783d70326fc 100644
--- a/paddle/phi/kernels/xpu/pow2_decay_with_linear_warmup_kernel.cc
+++ b/paddle/phi/kernels/xpu/pow2_decay_with_linear_warmup_kernel.cc
@@ -41,7 +41,7 @@ void Pow2DecayWithLinearWarmupKernel(const Context& dev_ctx,
                     phi::errors::InvalidArgument(
                         "Input(Step) and Output(StepOut) must be the same."));
   PADDLE_ENFORCE_EQ(
-      step.IsInitialized(),
+      step.initialized(),
       true,
       phi::errors::InvalidArgument("Input(Step) must be initialized."));
 
@@ -68,4 +68,7 @@ PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
                    XPU,
                    ALL_LAYOUT,
                    phi::Pow2DecayWithLinearWarmupKernel,
-                   float) {}
+                   float) {
+  kernel->InputAt(1).SetDataType(phi::DataType::INT64);
+  kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
+}
diff --git a/paddle/phi/kernels/xpu/reduce_sum_kernel.cc b/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
index 48a339ab51ce424bb3db31dc326e006bd2c4b4ed..9d8261f4246e10d055cd171ed853fad96881e1d3 100644
--- a/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
@@ -45,4 +45,6 @@ PD_REGISTER_KERNEL(sum_raw,
                    phi::dtype::float16,
                    int8_t,
                    int,
-                   int64_t) {}
+                   int64_t) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
+}