diff --git a/paddle/fluid/operators/assign_op_xpu.cc b/paddle/fluid/operators/assign_op_xpu.cc deleted file mode 100644 index be78337aa1562f521ada7a54c8302bdcc95846de..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/assign_op_xpu.cc +++ /dev/null @@ -1,166 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU -#include - -#include "paddle/fluid/operators/assign_op.h" - -namespace paddle { -namespace framework { -class OpDesc; -class Variable; -} // namespace framework -namespace imperative { -class OpBase; -} // namespace imperative -} // namespace paddle - -namespace paddle { -namespace operators { - -class AssignOp : public framework::OperatorWithKernel { - public: - AssignOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorWithKernel(type, inputs, outputs, attrs) {} - - void InferShape(framework::InferShapeContext *ctx) const override { - if (ctx->HasInput("X")) { - auto type = ctx->GetInputsVarType("X")[0]; - if (type == framework::proto::VarType::SELECTED_ROWS || - type == framework::proto::VarType::LOD_TENSOR) { - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - if (type == framework::proto::VarType::LOD_TENSOR) { - ctx->ShareLoD("X", /*->*/ "Out"); - } - } else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) { - if (ctx->IsRuntime()) { - // The runtime output shape is determined in kernel. - return; - } else { - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - } - } - } - } - - protected: - framework::OpKernelType GetKernelTypeForVar( - const std::string &var_name, - const framework::Tensor &tensor, - const framework::OpKernelType &expected_kernel_type) const override { - return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); - } - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - const framework::Variable *var = ctx.InputVar("X"); - if (var->IsType()) { - auto t_arr = var->Get(); - // NOTE(liym27): Support an empty tensor array as Input. - // And set the kernel type is float. - if (t_arr.size() == 0) { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); - } - } - - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); - } -}; - -class AssignInferVarType : public framework::VarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - ctx->SyncTypeAndDataType("X", "Out"); - } -}; - -class AssignKernel { - public: - void operator()(const framework::ExecutionContext &ctx) const { - auto *x = ctx.InputVar("X"); - if (x == nullptr) { - return; - } - PADDLE_ENFORCE_EQ( - ctx.HasOutput("Out"), - true, - platform::errors::NotFound("Output(Out) of assign_op is not found.")); - auto *out = ctx.OutputVar("Out"); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(ctx.GetPlace()); - - framework::VisitVarType(*x, AssignFunctor(out, dev_ctx)); - } -}; - -class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "(LoDTensor, SelectedRows or LoDTensorArray) The input variable " - "could be LoDTensor, SelectedRows or LoDTensorArray.") - .AsDispensable(); - AddOutput("Out", - "(LoDTensor, SelectedRows or LoDTensorArray) The type of output " - "is the same as input X."); - AddComment(R"DOC(Assign Operator - -Out = X, when type in [LoDTensor/SelectedRows/LoDTensorArray] -raise error if the type is not listed above. -)DOC"); - } -}; - -template -class AssignGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("assign"); - op->SetInput("X", this->OutputGrad("Out")); - op->SetOutput("Out", this->InputGrad("X")); - } -}; - -DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"}); - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_XPU_KERNEL_FUNCTOR(assign, - float, - ops::AssignKernel, - double, - ops::AssignKernel, - int, - ops::AssignKernel, - int64_t, - ops::AssignKernel, - bool, - ops::AssignKernel); -#endif diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc index 16e9bb384b5f3f78e97203c760646fe3fe7df634..bf030e6fb4b5fd45f6d2a6a800a3ac0c3ae1bdd1 100644 --- a/paddle/phi/kernels/assign_kernel.cc +++ b/paddle/phi/kernels/assign_kernel.cc @@ -158,3 +158,20 @@ PD_REGISTER_KERNEL(assign_value, float, int64_t) {} #endif + +#ifdef PADDLE_WITH_XPU +PD_REGISTER_GENERAL_KERNEL( + assign, XPU, ALL_LAYOUT, phi::AssignKernel, ALL_DTYPE) {} +PD_REGISTER_GENERAL_KERNEL(assign_raw, + XPU, + ALL_LAYOUT, + phi::AssignRawKernel, + ALL_DTYPE) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); +} +PD_REGISTER_GENERAL_KERNEL(assign_array, + XPU, + ALL_LAYOUT, + phi::AssignArrayKernel, + ALL_DTYPE) {} +#endif diff --git a/paddle/phi/kernels/selected_rows/assign_kernel.cc b/paddle/phi/kernels/selected_rows/assign_kernel.cc index f0c0ffb591a11de99847d31a191bd50229f88f29..993c5f81d347f45b541f80c81ce009209f5ed471 100644 --- a/paddle/phi/kernels/selected_rows/assign_kernel.cc +++ b/paddle/phi/kernels/selected_rows/assign_kernel.cc @@ -47,3 +47,11 @@ PD_REGISTER_GENERAL_KERNEL(assign_sr, phi::sr::AssignKernel, ALL_DTYPE) {} #endif + +#ifdef PADDLE_WITH_XPU +PD_REGISTER_GENERAL_KERNEL(assign_sr, + XPU, + ALL_LAYOUT, + phi::sr::AssignKernel, + ALL_DTYPE) {} +#endif