diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake
index 312a0305244684c88e8926d2a71db377b0dd6be1..e09429bc42957562423681fc4b0a1a3ea70e85cb 100644
--- a/cmake/cuda.cmake
+++ b/cmake/cuda.cmake
@@ -11,7 +11,7 @@ elseif(NEW_RELEASE_ALL)
   add_definitions(-DNEW_RELEASE_ALL)
   set(paddle_known_gpu_archs "35 50 52 60 61 70 75 80 86")
   set(paddle_known_gpu_archs10 "35 50 52 60 61 70 75")
-  set(paddle_known_gpu_archs11 "35 50 52 60 61 70 75 80")
+  set(paddle_known_gpu_archs11 "35 50 60 61 70 75 80")
 elseif(NEW_RELEASE_PYPI)
   message("Using New Release Strategy - Cubin Packge")
   add_definitions(-DNEW_RELEASE_PYPI)
diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
index 97114b9fe1189a9d9774eeb223e69926c3e11cbc..1c33d1c2f4f0b52d21584e099c5bf3a0dabd5f6e 100644
--- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
@@ -696,14 +696,15 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
     FUNCTION_TEMPLATE = """
 std::vector> {}::operator()(const std::vector>& grads, bool create_graph) {{
     // Call grad_api function
+    VLOG(3) << \"Finally State Running: \" << \"{}\";
     auto grad_api_returns = {}::{}({});
     {}
 }}
 """
     node_definition_str = FUNCTION_TEMPLATE.format(
-        grad_node_name, grad_api_namespace, bwd_api_name, grad_api_args_str,
-        returns_str)
+        grad_node_name, grad_node_name, grad_api_namespace, bwd_api_name,
+        grad_api_args_str, returns_str)
     return node_definition_str
diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc
index ebd3333c5265990a8ae2fb6840113bd0ea4d4766..0e9dc19c2e310e75e32d52d011a65630ea5b967d 100644
--- a/paddle/fluid/eager/backward.cc
+++ b/paddle/fluid/eager/backward.cc
@@ -612,7 +612,9 @@ std::vector RunBackward(
     for (size_t i = 0; i < edges.size(); i++) {
       for (size_t j = 0; j < edges[i].size(); j++) {
         const Edge& edge = edges[i][j];
-
+        if (!edge.IsInitialized()) {
+          continue;
+        }
         auto edge_rank = edge.GetEdgeRankInfo();
         // Since we make edge has as same rank as bwd outputs, we indexing them
         // with
diff --git a/paddle/fluid/eager/grad_node_info.cc b/paddle/fluid/eager/grad_node_info.cc
index 891ad4d8983b5b37b31ab5f5f980e74ccff47069..1d44d842b0825aa96380c947c67082fbcb5e1642 100644
--- a/paddle/fluid/eager/grad_node_info.cc
+++ b/paddle/fluid/eager/grad_node_info.cc
@@ -63,6 +63,8 @@ void GradNodeBase::AddEdges(std::vector* metas, size_t slot_id) {
       adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
                                        meta->OutRankInfo());
+    } else {
+      adj_edges_[slot_id].emplace_back();
     }
   }
 }
@@ -85,6 +87,8 @@ void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) {
     adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
                                      meta->OutRankInfo());
+  } else {
+    adj_edges_[slot_id].emplace_back();
   }
 }
diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h
index 4b21a193ee021f06538e1a11bbffb898376739a7..28c12717a24b0c89b8a3b6544124ad6533d6c70d 100644
--- a/paddle/fluid/eager/grad_node_info.h
+++ b/paddle/fluid/eager/grad_node_info.h
@@ -257,12 +257,22 @@ class Edge {
   }
   // Currently we use grad_node_ to identify if a edge is initialized.
-  bool IsInitialized() const { return grad_node_.get(); }
+  bool IsInitialized() const {
+    if (!grad_node_) {
+      return false;
+    } else {
+      if (!(grad_node_.get())) {
+        return false;
+      } else {
+        return true;
+      }
+    }
+  }

  private:
   size_t in_slot_id_;
   size_t in_rank_;
-  std::shared_ptr grad_node_;
+  std::shared_ptr grad_node_{nullptr};
 };
 }  // namespace egr
diff --git a/paddle/fluid/operators/lookup_table_v2_op.cc b/paddle/fluid/operators/lookup_table_v2_op.cc
index 47a00a93a647253305080bde2d8c98eb735513d6..48ae080783d112c7e11daebe984de70925f5bbe2 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.cc
+++ b/paddle/fluid/operators/lookup_table_v2_op.cc
@@ -203,14 +203,6 @@ REGISTER_OPERATOR(lookup_table_v2_grad, ops::LookupTableV2OpGrad,
                   ops::LookupTableV2GradOpNoBufferVarsInferer,
                   ops::LookupTableV2OpGradVarTypeInference);
-REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel,
-                       ops::LookupTableV2Kernel,
-                       ops::LookupTableV2Kernel);
-REGISTER_OP_CPU_KERNEL(
-    lookup_table_v2_grad, ops::LookupTableV2GradKernel,
-    ops::LookupTableV2GradKernel,
-    ops::LookupTableV2GradKernel);
-
 /* ========================== register checkpoint ===========================*/
 REGISTER_OP_VERSION(lookup_table_v2)
     .AddCheckpoint(
diff --git a/paddle/fluid/operators/lookup_table_v2_op.cu b/paddle/fluid/operators/lookup_table_v2_op.cu
index d40b2643785706e843dbd9812e74ca0aa134f7b5..74d089e23a82c6ea15988d9c3c0c3e5b42da8b2e 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.cu
+++ b/paddle/fluid/operators/lookup_table_v2_op.cu
@@ -235,13 +235,3 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel {
 }  // namespace operators
 }  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(lookup_table_v2, ops::LookupTableV2CUDAKernel,
-                        ops::LookupTableV2CUDAKernel,
-                        ops::LookupTableV2CUDAKernel);
-REGISTER_OP_CUDA_KERNEL(lookup_table_v2_grad,
-                        ops::LookupTableV2GradCUDAKernel,
-                        ops::LookupTableV2GradCUDAKernel,
-                        ops::LookupTableV2GradCUDAKernel);
diff --git a/paddle/fluid/operators/reverse_op.cc b/paddle/fluid/operators/reverse_op.cc
index 98a1610be607e8bcd6d14a25a45d1856a64dbe8a..975eecafc06a615e7d958fe1271679719f68c241 100644
--- a/paddle/fluid/operators/reverse_op.cc
+++ b/paddle/fluid/operators/reverse_op.cc
@@ -12,60 +12,20 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/reverse_op.h" #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace operators { class ReverseOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Reverse"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Reverse"); - - auto x_var_type = ctx->GetInputsVarType("X")[0]; - const auto& axis = ctx->Attrs().Get>("axis"); - if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) { - PADDLE_ENFORCE_EQ( - axis.size(), 1, - platform::errors::InvalidArgument( - "The size of axis must be 1 when the Input(X) is LoDTensorArray, " - "but received %d.", - axis.size())); - PADDLE_ENFORCE_EQ(axis[0], 0, platform::errors::InvalidArgument( - "The value of axis should be 1 when " - "the Input(X) is LoDTensorArray, " - "but received %d.", - axis[0])); - // In runtime, shape is determined by RunImpl. - if (!ctx->IsRuntime()) { - const auto& x_dims = ctx->GetInputDim("X"); - ctx->SetOutputDim("Out", x_dims); - } - return; - } - const auto& x_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_NE(axis.empty(), true, platform::errors::InvalidArgument( - "'axis' can not be empty.")); - for (int a : axis) { - PADDLE_ENFORCE_LT(a, x_dims.size(), - paddle::platform::errors::OutOfRange( - "The axis must be less than input tensor's rank. " - "but got %d >= %d", - a, x_dims.size())); - PADDLE_ENFORCE_GE( - a, -x_dims.size(), - paddle::platform::errors::OutOfRange( - "The axis must be greater than the negative number of " - "input tensor's rank, but got %d < %d", - a, -x_dims.size())); - } - ctx->SetOutputDim("Out", x_dims); - } }; class ReverseOpVarTypeInference : public framework::VarTypeInference { @@ -134,23 +94,10 @@ class ReverseGradMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(reverse, ReverseInferShapeFunctor, + PD_INFER_META(phi::ReverseInferMeta)); REGISTER_OPERATOR(reverse, ops::ReverseOp, ops::ReverseOpMaker, ops::ReverseGradMaker, ops::ReverseGradMaker, - ops::ReverseOpVarTypeInference); + ops::ReverseOpVarTypeInference, ReverseInferShapeFunctor); REGISTER_OPERATOR(reverse_grad, ops::ReverseOp, ops::ReverseOpVarTypeInference); -REGISTER_OP_CPU_KERNEL( - reverse, ops::ReverseKernel, - ops::ReverseKernel, - ops::ReverseKernel, - ops::ReverseKernel, - ops::ReverseKernel, - ops::ReverseKernel); - -REGISTER_OP_CUDA_KERNEL( - reverse, ops::ReverseKernel, - ops::ReverseKernel, - ops::ReverseKernel, - ops::ReverseKernel, - ops::ReverseKernel, - ops::ReverseKernel); diff --git a/paddle/fluid/operators/reverse_op.h b/paddle/fluid/operators/reverse_op.h deleted file mode 100644 index d5e331e2fe5f697c7f1604a82bc1dd08bfbd276e..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reverse_op.h +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/eigen/eigen_function.h" - -namespace paddle { -namespace operators { -template -struct ReverseFunctor { - void operator()(const DeviceContext& context, const framework::LoDTensor& in, - framework::LoDTensor* out, const std::vector& axis) { - Eigen::DSizes reverse_axis; - for (int i = 0; i < Rank; ++i) { - reverse_axis[i] = false; - } - for (int a : axis) { - if (a >= 0) { - reverse_axis[a] = true; - } else { - reverse_axis[Rank + a] = true; - } - } - - auto in_eigen = framework::EigenTensor::From(in); - auto out_eigen = framework::EigenTensor::From(*out); - auto& dev = *context.eigen_device(); - - EigenReverse, T, Rank>::Eval( - dev, out_eigen, in_eigen, reverse_axis); - } -}; - -template -class ReverseKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x_var = context.InputVar("X"); - const auto& axis = context.Attr>("axis"); - if (x_var->IsType()) { - auto& x_array = x_var->Get(); - auto* out_array = context.Output("Out"); - - out_array->resize(x_array.size()); - for (size_t offset = 0; offset < x_array.size(); offset++) { - auto& x_tensor = x_array.at(offset); - PADDLE_ENFORCE_GT( - x_tensor.memory_size(), 0, - platform::errors::PreconditionNotMet( - "The input LoDTensorArray X[%d] holds no memory.", offset)); - auto out_offset = x_array.size() - offset - 1; - auto* out_tensor = &out_array->at(out_offset); - - out_tensor->set_lod(x_tensor.lod()); - paddle::framework::TensorCopy(x_tensor, context.GetPlace(), out_tensor); - } - return; - } - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - out->mutable_data(context.GetPlace()); - int rank = x->dims().size(); - auto& dev_ctx = context.template device_context(); - - switch (rank) { - case 1: - ReverseFunctor functor1; - functor1(dev_ctx, *x, out, axis); - break; - case 2: - ReverseFunctor functor2; - functor2(dev_ctx, *x, out, axis); - break; - case 3: - ReverseFunctor functor3; - functor3(dev_ctx, *x, out, axis); - break; - case 4: - ReverseFunctor functor4; - functor4(dev_ctx, *x, out, axis); - break; - case 5: - ReverseFunctor functor5; - functor5(dev_ctx, *x, out, axis); - break; - case 6: - ReverseFunctor functor6; - functor6(dev_ctx, *x, out, axis); - break; - default: - PADDLE_THROW(paddle::platform::errors::OutOfRange( - "The reserve operator does not support input tensors" - "whose ranks are greater than 6.")); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/infrt/dialect/phi/CMakeLists.txt b/paddle/infrt/dialect/phi/CMakeLists.txt index 67f6bb8a2d7bbfa604614e4909169c08ea18e1b3..436ff0a40480cfe2d0907309a8eecf8cdc32adb2 100644 --- a/paddle/infrt/dialect/phi/CMakeLists.txt +++ b/paddle/infrt/dialect/phi/CMakeLists.txt @@ -5,6 +5,10 @@ endif() add_subdirectory(ir) add_subdirectory(pass) +add_executable(phi-ir-exec phi_ir_exec.cc) +target_link_libraries(phi-ir-exec infrt) + + add_executable(phi-exec phi_exec.cc) 
target_link_libraries(phi-exec infrt) diff --git a/paddle/infrt/dialect/phi/ir/infrt_phi_base.td b/paddle/infrt/dialect/phi/ir/infrt_phi_base.td index 8e21283183d036ac26c117a0a209ba92d1f9febc..376d62deecee7cfb958f7dbb180b5936091f8acf 100644 --- a/paddle/infrt/dialect/phi/ir/infrt_phi_base.td +++ b/paddle/infrt/dialect/phi/ir/infrt_phi_base.td @@ -18,8 +18,8 @@ def PHI_Dialect : Dialect { def PhiOpTrait : NativeOpTrait<"PhiOpTrait">; -class PHI_Type traits = []> - : TypeDef {} +class PHI_Type traits = [], string baseCppClass = "::mlir::Type"> + : TypeDef {} def Allocator : PHI_Type<"Allocator"> { let mnemonic = "allocator"; diff --git a/paddle/infrt/host_context/paddle_mlir.cc b/paddle/infrt/host_context/paddle_mlir.cc index e161dc47075bb3e87399477b3112a4c4c57cec1c..ec12815e3ce94f52b987e845363880cfc3896387 100644 --- a/paddle/infrt/host_context/paddle_mlir.cc +++ b/paddle/infrt/host_context/paddle_mlir.cc @@ -16,6 +16,7 @@ #include "paddle/infrt/dialect/infrt/ir/basic_kernels.h" #include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h" #include "paddle/infrt/dialect/pd/common/pd_ops_info.h" +#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h" MLIRModelGenImpl::MLIRModelGenImpl() : context_(infrt::Global::getMLIRContext()), builder_(context_) { @@ -24,6 +25,8 @@ MLIRModelGenImpl::MLIRModelGenImpl() context_->getOrLoadDialect(); context_->getOrLoadDialect(); context_->getOrLoadDialect<::infrt::InfrtDialect>(); + context_->getOrLoadDialect<::infrt::phi::PHIDialect>(); + context_->getOrLoadDialect<::infrt::phi::PHIDenseTensorDialect>(); module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(context_)); } @@ -79,7 +82,7 @@ mlir::FuncOp MLIRModelGenImpl::UpdateModelModule( llvm::SmallVector MLIRModelGenImpl::GetModelInputsType( const infrt::paddle::framework_proto::ProgramDesc &program) { llvm::SmallVector operandTypes; - operandTypes.push_back(infrt::DenseHostTensorMapType::get(context_)); + operandTypes.push_back(infrt::phi::DenseTensorMapType::get(context_)); for (auto &op_desc : main_block_.ops()) { if (op_desc.type() != "feed") continue; for (int var_idx = 0; var_idx < op_desc.outputs_size(); ++var_idx) { @@ -180,7 +183,7 @@ void MLIRModelGenImpl::UpdateModelParams( &precision_); mlir::Type type_ = infrt::DenseTensorType::get( context_, infrt::TargetType::CPU, precision_, infrt::LayoutType::ANY); - auto op = builder_.create( + auto op = builder_.create<::infrt::phi::TensorMapGetTensorOp>( mlir::UnknownLoc::get(context_), type_, map, name); params_map_.insert(std::pair( var_desc.name(), op.getOperation()->getResult(0))); diff --git a/paddle/infrt/tests/CMakeLists.txt b/paddle/infrt/tests/CMakeLists.txt index 58543a6864258bd6c0153150bb535262d9a8f00d..6f839cdc3954939e8c8d4792facac5a284d25f3f 100644 --- a/paddle/infrt/tests/CMakeLists.txt +++ b/paddle/infrt/tests/CMakeLists.txt @@ -6,3 +6,4 @@ add_test(NAME test_infrt_by_lit COMMAND sh -c "lit -v ${CMAKE_SOURCE_DIR}/paddle DEPENDS infrtopt infrtexec) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir) +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/phi/linear_cpu.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/phi/linear_cpu.mlir) diff --git a/paddle/infrt/tests/dialect/phi/linear_cpu.mlir.in b/paddle/infrt/tests/dialect/phi/linear_cpu.mlir.in new file mode 100644 index 0000000000000000000000000000000000000000..7ca33fa10a90d4dffef02526b2c19744e388a6aa --- /dev/null +++ b/paddle/infrt/tests/dialect/phi/linear_cpu.mlir.in @@ -0,0 +1,19 @@ +// RUN: infrtexec 
-i %s +module { + func @main_graph(%arg0: !phi.dense_tensor_map, %arg1: !infrt.dense_tensor) -> !infrt.dense_tensor { + %0 = phi_dt.tensor_map_get_tensor(%arg0) {name = "linear_0.w_0"} -> !infrt.dense_tensor + %1 = phi_dt.tensor_map_get_tensor(%arg0) {name = "linear_0.b_0"} -> !infrt.dense_tensor + %2 = "phi_dt.create_context.cpu"() : () -> !phi.context + %5 = "phi_cpu.matmul.float32.any"(%2, %arg1, %0) {trans_x = false, trans_y = false} : (!phi.context, !infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor + %7 = "phi_cpu.add.float32.any"(%2, %5, %1): (!phi.context, !infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor + infrt.return %7 : !infrt.dense_tensor + } + func @main() { + %ctx = "phi_dt.create_context.cpu" (): () -> !phi.context + %1 = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision, layout=#infrt.layout, lod=[1:i64], dims=[16:i64, 784:i64]}: (!phi.context) -> (!infrt.dense_tensor) + %map = phi_dt.load_combined_params(){model_path="@CMAKE_BINARY_DIR@/linear/linear.pdmodel",params_path="@CMAKE_BINARY_DIR@/linear/linear.pdiparams"} + %2 = infrt.call@main_graph(%map, %1) : (!phi.dense_tensor_map, !infrt.dense_tensor) -> !infrt.dense_tensor + phi_dt.print_tensor (%2 : !infrt.dense_tensor) + infrt.return + } +} diff --git a/paddle/infrt/tests/model/linear.py b/paddle/infrt/tests/model/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..602e067365b87318ecb847d13832810aa1db4593 --- /dev/null +++ b/paddle/infrt/tests/model/linear.py @@ -0,0 +1,80 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# example 1: save layer +import numpy as np +import paddle +import paddle.nn as nn +import paddle.optimizer as opt + +BATCH_SIZE = 16 +BATCH_NUM = 4 +EPOCH_NUM = 4 + +IMAGE_SIZE = 784 +CLASS_NUM = 10 + + +# define a random dataset +class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + +class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) + + @paddle.jit.to_static + def forward(self, x): + return self._linear(x) + + +def train(layer, loader, loss_fn, opt): + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + opt.step() + opt.clear_grad() + + +# 1. train & save model. 
+ +# create network +layer = LinearNet() +loss_fn = nn.CrossEntropyLoss() +adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters()) + +# create data loader +dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) +loader = paddle.io.DataLoader( + dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2) + +# train +train(layer, loader, loss_fn, adam) + +# save +path = "linear/linear" +paddle.jit.save(layer, path) diff --git a/paddle/phi/api/include/context_pool.h b/paddle/phi/api/include/context_pool.h index 754833a2ddab3601f61069a916aea05181425c8f..a2983d9c2aa656e072b7cef010e220201ae3857f 100644 --- a/paddle/phi/api/include/context_pool.h +++ b/paddle/phi/api/include/context_pool.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once +#include + #include "paddle/phi/common/place.h" #include "paddle/phi/core/macros.h" #include "paddle/utils/flat_hash_map.h" @@ -58,21 +60,22 @@ class DeviceContextPool { public: static DeviceContextPool& Instance(); - const phi::DeviceContext* Get(const Place& place) const; + const phi::DeviceContext* Get(const Place& place); phi::DeviceContext* GetMutable(const Place& place); template - const typename DefaultDeviceContextType::TYPE* Get( - const Place& place) const { + const typename DefaultDeviceContextType::TYPE* Get(const Place& place) { return reinterpret_cast::TYPE*>( Get(place)); } private: - DeviceContextPool(); + DeviceContextPool() = default; + paddle::flat_hash_map context_map_; + std::mutex mutex_; DISABLE_COPY_AND_ASSIGN(DeviceContextPool); }; diff --git a/paddle/phi/api/lib/context_pool.cc b/paddle/phi/api/lib/context_pool.cc index d1408a88d6ff784039f9e45393d9aec9ff37df2a..07ac9822d3310e2c3976296168b2c4527e082274 100644 --- a/paddle/phi/api/lib/context_pool.cc +++ b/paddle/phi/api/lib/context_pool.cc @@ -25,12 +25,17 @@ DeviceContextPool& DeviceContextPool::Instance() { return g_device_context_pool; } -const phi::DeviceContext* DeviceContextPool::Get(const Place& place) const { +const phi::DeviceContext* DeviceContextPool::Get(const Place& place) { auto it = context_map_.find(place); - PADDLE_ENFORCE_NE( - it, - context_map_.end(), - phi::errors::NotFound("The DeviceContext of %s does not exists.", place)); + if (it == context_map_.end()) { + // only when we need the specific DeviceContext, get and cache it + auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place); + { + std::lock_guard lock(mutex_); + context_map_[place] = dev_ctx; + } + return dev_ctx; + } return it->second; } @@ -38,28 +43,5 @@ phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) { return const_cast(Get(place)); } -DeviceContextPool::DeviceContextPool() { - // We need to make sure that the correct value exists - // whenever we get the DeviceContext from DeviceContextPool - const auto& device_contexts = - paddle::platform::DeviceContextPool::Instance().device_contexts(); - for (const auto& pair : device_contexts) { - // only get CPU and GPU DeviceContext now, add other DeviceContext type - // later if needed - if (platform::is_cpu_place(pair.first) -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - || - platform::is_gpu_place(pair.first)) { -#else - ) { -#endif - const phi::DeviceContext* dev_ctx = pair.second.get().get(); - VLOG(3) << "Init phi DeviceContextPool: insert {" << pair.first << ", " - << dev_ctx << "}"; - context_map_[pair.first] = dev_ctx; - } - } -} - } // namespace experimental } // namespace paddle diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 
36a049eca0f30c4d5d292d23b94cbead53c71208..5221076f10daa8de680cae1fd271fc3bd68ba797 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -23,8 +23,6 @@ limitations under the License. */ #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/funcs/common_shape.h" -#include "paddle/phi/kernels/cpu/conv_util.h" - namespace phi { namespace detail { @@ -469,6 +467,31 @@ void ConvInferMeta(const MetaTensor& input, out->set_dtype(input.dtype()); } +void ConvInferInferMeta(const MetaTensor& input, + const MetaTensor& filter, + const std::vector& strides, + const std::vector& paddings, + const std::string& paddding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + MetaTensor* out, + MetaConfig config) { + ConvInferMeta(input, + filter, + strides, + paddings, + paddding_algorithm, + groups, + dilations, + data_format, + /*use_addto=*/false, + /*workspace_size_MB=*/512, // useless in infermeta + /*exhaustive_search=*/false, + out, + config); +} + void ConvTransposeInferMeta(const MetaTensor& x, const MetaTensor& filter, const std::vector& strides, @@ -1670,3 +1693,4 @@ void ValueCompareInferMeta(const MetaTensor& x, PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta); PD_REGISTER_INFER_META_FN(conv2d, phi::ConvInferMeta); +PD_REGISTER_INFER_META_FN(conv2d_infer, phi::ConvInferInferMeta); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 9a54c4c5fa62d4c58e527b9efbf2e977f72354ec..f9a939843775374eeddfac8f198f19f8c8dc10eb 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -83,6 +83,17 @@ void ConvInferMeta(const MetaTensor& input, MetaTensor* out, MetaConfig config = MetaConfig()); +void ConvInferInferMeta(const MetaTensor& input, + const MetaTensor& filter, + const std::vector& strides, + const std::vector& paddings, + const std::string& paddding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void ConvTransposeInferMeta(const MetaTensor& x, const MetaTensor& filter, const std::vector& strides, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 7c5f38744f8923805d1e9b521c58813293cdce9b..80503dd2430927223dedd80d8e44c08473536997 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1240,6 +1240,33 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, ReshapeInferMeta(x, shape, out, config); } +void ReverseInferMeta(const MetaTensor& x, + const std::vector& axis, + MetaTensor* out) { + PADDLE_ENFORCE_NE(axis.empty(), + true, + phi::errors::InvalidArgument("'axis' can not be empty.")); + const auto& x_dims = x.dims(); + for (int a : axis) { + PADDLE_ENFORCE_LT(a, + x_dims.size(), + phi::errors::OutOfRange( + "The axis must be less than input tensor's rank. 
" + "but got %d >= %d", + a, + x_dims.size())); + PADDLE_ENFORCE_GE( + a, + -x_dims.size(), + phi::errors::OutOfRange( + "The axis must be greater than the negative number of " + "input tensor's rank, but got %d < %d", + a, + -x_dims.size())); + } + out->share_meta(x); +} + void RollInferMeta(const MetaTensor& x, const ScalarArray& shifts, const std::vector& axis, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index d84283a65c4d19445dce61e9cf8ee6f70a83905f..0322a18fc3153b996e03aace0f705f1a776ad99f 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -198,6 +198,10 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void ReverseInferMeta(const MetaTensor& x, + const std::vector& axis, + MetaTensor* out); + void RollInferMeta(const MetaTensor& x, const ScalarArray& shifts, const std::vector& axis, diff --git a/paddle/phi/kernels/conv_kernel.cc b/paddle/phi/kernels/conv_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..7268384f401a1f9c30555aeda4521c35bb76a677 --- /dev/null +++ b/paddle/phi/kernels/conv_kernel.cc @@ -0,0 +1,57 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/conv_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/fluid/platform/cudnn_workspace_helper.h" + +namespace phi { + +template +void ConvInferKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const std::vector& strides, + const std::vector& paddings, + const std::string& paddding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + DenseTensor* out) { + ConvKernel(dev_ctx, + input, + filter, + strides, + paddings, + paddding_algorithm, + groups, + dilations, + data_format, + /*use_addto=*/false, + /*workspace_size_MB=*/paddle::platform:: + GetDefaultConvWorkspaceSizeLimitMB(), + /*exhaustive_search=*/false, + out); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + conv2d_infer, CPU, ALL_LAYOUT, phi::ConvInferKernel, float, double) {} +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL( + conv2d_infer, GPU, ALL_LAYOUT, phi::ConvInferKernel, float, double) {} +#endif diff --git a/paddle/phi/kernels/conv_kernel.h b/paddle/phi/kernels/conv_kernel.h index eb0bfdd0275b5050054c620e722b0e7653fd678a..508b3a42a21addb8b29cbe19d00c93782120a4fe 100644 --- a/paddle/phi/kernels/conv_kernel.h +++ b/paddle/phi/kernels/conv_kernel.h @@ -64,4 +64,16 @@ void DepthwiseConvKernel(const Context& dev_ctx, bool fuse_relu, DenseTensor* out); +template +void ConvInferKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const std::vector& strides, + const std::vector& paddings, + const std::string& paddding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + DenseTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/cpu/embedding_grad_kernel.cc b/paddle/phi/kernels/cpu/embedding_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..21b3e6da8d9efdac1e5866ef3ac1aac580d5a0b8 --- /dev/null +++ b/paddle/phi/kernels/cpu/embedding_grad_kernel.cc @@ -0,0 +1,220 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/embedding_grad_kernel.h" +#include "paddle/phi/kernels/funcs/embedding_util.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +struct EmbeddingGradCPUFunctor { + EmbeddingGradCPUFunctor(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + DenseTensor* weight_grad) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_grad_(out_grad), + weight_grad_(weight_grad), + padding_idx_(padding_idx) {} + + template + void apply() { + DDim table_dim = weight_.dims(); + + auto ids = CopyIdsToVector(input_); + auto ids_num = static_cast(ids.size()); + + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + { + auto* d_output = &out_grad_; + auto* ids_data = ids.data(); + + int64_t N = table_dim[0]; + int64_t D = table_dim[1]; + + auto* d_output_data = d_output->template data(); + + dev_ctx_.template Alloc(weight_grad_); + auto* d_table_data = weight_grad_->data(); + + memset(d_table_data, 0, weight_grad_->numel() * sizeof(T)); + + for (int64_t i = 0; i < ids_num; ++i) { + if (padding_idx_ != kNoPadding && ids_data[i] == padding_idx_) { + // the gradient of padding_idx should be 0, already done by memset, so + // do nothing. + } else { + PADDLE_ENFORCE_LT( + ids_data[i], + N, + phi::errors::InvalidArgument( + "Variable value (input) of " + "OP(paddle.nn.functional.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + N, + ids_data[i])); + PADDLE_ENFORCE_GE( + ids_data[i], + 0, + phi::errors::InvalidArgument( + "Variable value (input) of " + "OP(paddle.nn.functional.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + N, + ids_data[i])); + for (int j = 0; j < D; ++j) { + d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j]; + } + } + } + } + } + + private: + const Context& dev_ctx_; + const DenseTensor& input_; + const DenseTensor& weight_; + const DenseTensor& out_grad_; + DenseTensor* weight_grad_; + int64_t padding_idx_; +}; + +template +void EmbeddingGradKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + DenseTensor* weight_grad) { + EmbeddingGradCPUFunctor functor( + ctx, input, weight, out_grad, padding_idx, weight_grad); + if (input.dtype() == phi::DataType::INT32) { + functor.template apply(); + } else if (input.dtype() == phi::DataType::INT64) { + functor.template apply(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "emebdding input only support int32 and int64")); + } +} + +template +struct EmbeddingSparseGradCPUFunctor { + EmbeddingSparseGradCPUFunctor(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_grad_(out_grad), + weight_grad_(weight_grad), + padding_idx_(padding_idx) {} + + template + void apply() { + DDim table_dim = weight_.dims(); + + auto ids = CopyIdsToVector(input_); + auto ids_num = static_cast(ids.size()); + + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. 
+ auto* d_table = weight_grad_; + auto* d_output = &out_grad_; + d_table->set_rows(ids); + + auto* d_table_value = d_table->mutable_value(); + d_table_value->Resize({ids_num, table_dim[1]}); + + dev_ctx_.template Alloc(d_table_value); + + d_table->set_height(table_dim[0]); + + auto* d_output_data = d_output->template data(); + auto* d_table_data = d_table_value->template data(); + + auto d_output_dims = d_output->dims(); + auto d_output_dims_2d = + flatten_to_2d(d_output_dims, d_output_dims.size() - 1); + PADDLE_ENFORCE_EQ(d_table_value->dims(), + d_output_dims_2d, + phi::errors::InvalidArgument( + "ShapeError: The shape of lookup_table@Grad and " + "output@Grad should be same. " + "But received lookup_table@Grad's shape = [%s], " + "output@Grad's shape = [%s].", + d_table_value->dims(), + d_output_dims_2d)); + memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); + } + + private: + const Context& dev_ctx_; + const DenseTensor& input_; + const DenseTensor& weight_; + const DenseTensor& out_grad_; + SelectedRows* weight_grad_; + int64_t padding_idx_; +}; + +template +void EmbeddingSparseGradKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) { + EmbeddingSparseGradCPUFunctor functor( + ctx, input, weight, out_grad, padding_idx, weight_grad); + if (input.dtype() == phi::DataType::INT32) { + functor.template apply(); + } else if (input.dtype() == phi::DataType::INT64) { + functor.template apply(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "emebdding input only support int32 and int64")); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(embedding_grad, + CPU, + ALL_LAYOUT, + phi::EmbeddingGradKernel, + float, + double, + phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL(embedding_sparse_grad, + CPU, + ALL_LAYOUT, + phi::EmbeddingSparseGradKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/cpu/embedding_kernel.cc b/paddle/phi/kernels/cpu/embedding_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..76cc3814b0567087ef8e5d40fe4031ed6598a49b --- /dev/null +++ b/paddle/phi/kernels/cpu/embedding_kernel.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/embedding_kernel.h" +#include "paddle/phi/kernels/funcs/embedding_util.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/data_type.h" + +namespace phi { + +template +struct EmbeddingCPUFunctor { + EmbeddingCPUFunctor(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + int64_t padding_idx, + DenseTensor* out) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_(out), + padding_idx_(padding_idx) {} + + template + void apply() { + auto ids = CopyIdsToVector(input_); + auto ids_numel = static_cast(ids.size()); + + int64_t row_number = weight_.dims()[0]; + int64_t row_width = weight_.dims()[1]; + + auto* table = weight_.data(); + + dev_ctx_.template Alloc(out_); + auto* output = out_->data(); + + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) { + memset(output + i * row_width, 0, row_width * sizeof(T)); + } else { + PADDLE_ENFORCE_LT( + ids[i], + row_number, + phi::errors::InvalidArgument( + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + row_number, + ids[i])); + PADDLE_ENFORCE_GE( + ids[i], + 0, + phi::errors::InvalidArgument( + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + row_number, + ids[i])); + memcpy(output + i * row_width, + table + ids[i] * row_width, + row_width * sizeof(T)); + } + } + } + + private: + const Context& dev_ctx_; + const DenseTensor& input_; + const DenseTensor& weight_; + DenseTensor* out_; + int64_t padding_idx_; +}; + +template +void EmbeddingKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + int64_t padding_idx, + DenseTensor* out) { + EmbeddingCPUFunctor functor(ctx, input, weight, padding_idx, out); + + if (input.dtype() == phi::DataType::INT32) { + functor.template apply(); + } else if (input.dtype() == phi::DataType::INT64) { + functor.template apply(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "emebdding input only support int32 and int64")); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(embedding, + CPU, + ALL_LAYOUT, + phi::EmbeddingKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/cpu/reverse_kernel.cc b/paddle/phi/kernels/cpu/reverse_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..43eff7c055090b982659959d86b9ae59372e89ee --- /dev/null +++ b/paddle/phi/kernels/cpu/reverse_kernel.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/reverse_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/reverse_kernel_impl.h" + +PD_REGISTER_KERNEL(reverse, + CPU, + ALL_LAYOUT, + phi::ReverseKernel, + int, + uint8_t, + int64_t, + bool, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc b/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..d78477073ad03b1b39aaae00c16aed81ea7fd056 --- /dev/null +++ b/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc @@ -0,0 +1,224 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h" +#include "paddle/phi/kernels/funcs/embedding_util.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/data_type.h" + +namespace phi { + +template +struct SparseWeightEmbeddingGradCPUFunctor { + SparseWeightEmbeddingGradCPUFunctor(const Context& dev_ctx, + const DenseTensor& input, + const SelectedRows& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + DenseTensor* weight_grad) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_grad_(out_grad), + weight_grad_(weight_grad), + padding_idx_(padding_idx) {} + + template + void apply() { + DDim table_dim = weight_.dims(); + + auto ids = CopyIdsToVector(input_); + auto ids_num = static_cast(ids.size()); + + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + { + auto* d_output = &out_grad_; + // auto d_table = weight_grad_; + auto* ids_data = ids.data(); + + int64_t N = table_dim[0]; + int64_t D = table_dim[1]; + + auto* d_output_data = d_output->template data(); + + dev_ctx_.template Alloc(weight_grad_); + auto* d_table_data = weight_grad_->data(); + + memset(d_table_data, 0, weight_grad_->numel() * sizeof(T)); + + for (int64_t i = 0; i < ids_num; ++i) { + if (padding_idx_ != kNoPadding && ids_data[i] == padding_idx_) { + // the gradient of padding_idx should be 0, already done by memset, so + // do nothing. + } else { + PADDLE_ENFORCE_LT( + ids_data[i], + N, + phi::errors::InvalidArgument( + "Variable value (input) of " + "OP(paddle.nn.functional.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + N, + ids_data[i])); + PADDLE_ENFORCE_GE( + ids_data[i], + 0, + phi::errors::InvalidArgument( + "Variable value (input) of " + "OP(paddle.nn.functional.embedding) " + "expected >= 0 and < %ld, but got %ld. 
Please check input " + "value.", + N, + ids_data[i])); + for (int j = 0; j < D; ++j) { + d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j]; + } + } + } + } + } + + private: + const Context& dev_ctx_; + const DenseTensor& input_; + const SelectedRows& weight_; + const DenseTensor& out_grad_; + DenseTensor* weight_grad_; + int64_t padding_idx_; +}; + +template +struct SparseWeightEmbeddingSparseGradCPUFunctor { + SparseWeightEmbeddingSparseGradCPUFunctor(const Context& dev_ctx, + const DenseTensor& input, + const SelectedRows& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_grad_(out_grad), + weight_grad_(weight_grad), + padding_idx_(padding_idx) {} + + template + void apply() { + DDim table_dim = weight_.dims(); + + auto ids = CopyIdsToVector(input_); + auto ids_num = static_cast(ids.size()); + + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + auto* d_table = weight_grad_; + auto* d_output = &out_grad_; + d_table->set_rows(ids); + + auto* d_table_value = d_table->mutable_value(); + d_table_value->Resize({ids_num, table_dim[1]}); + + dev_ctx_.template Alloc(d_table_value); + + d_table->set_height(table_dim[0]); + + auto* d_output_data = d_output->template data(); + auto* d_table_data = d_table_value->template data(); + + auto d_output_dims = d_output->dims(); + auto d_output_dims_2d = + phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1); + PADDLE_ENFORCE_EQ(d_table_value->dims(), + d_output_dims_2d, + phi::errors::InvalidArgument( + "ShapeError: The shape of lookup_table@Grad and " + "output@Grad should be same. " + "But received lookup_table@Grad's shape = [%s], " + "output@Grad's shape = [%s].", + d_table_value->dims(), + d_output_dims_2d)); + memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); + } + + private: + const Context& dev_ctx_; + const DenseTensor& input_; + const SelectedRows& weight_; + const DenseTensor& out_grad_; + SelectedRows* weight_grad_; + int64_t padding_idx_; +}; + +template +void SparseWeightEmbeddingGradKernel(const Context& ctx, + const DenseTensor& input, + const SelectedRows& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + DenseTensor* weight_grad) { + SparseWeightEmbeddingGradCPUFunctor functor( + ctx, input, weight, out_grad, padding_idx, weight_grad); + + if (input.dtype() == phi::DataType::INT32) { + functor.template apply(); + } else if (input.dtype() == phi::DataType::INT64) { + functor.template apply(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "emebdding input only support int32 and int64")); + } +} + +template +void SparseWeightEmbeddingSparseGradKernel(const Context& ctx, + const DenseTensor& input, + const SelectedRows& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) { + SparseWeightEmbeddingSparseGradCPUFunctor functor( + ctx, input, weight, out_grad, padding_idx, weight_grad); + + if (input.dtype() == phi::DataType::INT32) { + functor.template apply(); + } else if (input.dtype() == phi::DataType::INT64) { + functor.template apply(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "emebdding input only support int32 and int64")); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(sparse_weight_embedding_grad, + CPU, + ALL_LAYOUT, + phi::SparseWeightEmbeddingGradKernel, + float, + double, + phi::dtype::bfloat16) {} + 
+PD_REGISTER_KERNEL(sparse_weight_embedding_sparse_grad, + CPU, + ALL_LAYOUT, + phi::SparseWeightEmbeddingSparseGradKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc b/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..c0f95d03888b8df825341c282e08f80dafc988a8 --- /dev/null +++ b/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/embedding_kernel.h" +#include "paddle/phi/kernels/funcs/embedding_util.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/data_type.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" + +namespace phi { + +template +struct EmbeddingCPUSparseFunctor { + EmbeddingCPUSparseFunctor(const Context& dev_ctx, + const DenseTensor& input, + const SelectedRows& weight, + int64_t padding_idx, + DenseTensor* out) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_(out), + padding_idx_(padding_idx) {} + + template + void apply() { + auto ids = CopyIdsToVector(input_); + auto ids_numel = static_cast(ids.size()); + + const auto& table_t = weight_; + auto output_t = out_; + int64_t row_width = table_t.value().dims()[1]; + const auto* table = table_t.value().template data(); + auto* output = dev_ctx_.template Alloc(output_t); + auto input_data_type = + paddle::framework::TransToProtoVarType(table_t.value().dtype()); + + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) { + memset(output + i * row_width, 0, row_width * sizeof(T)); + } else { + PADDLE_ENFORCE_GE( + ids[i], + 0, + phi::errors::InvalidArgument( + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0. But received %ld", + ids[i])); + auto id_index = table_t.Index(ids[i]); + PADDLE_ENFORCE_GE( + id_index, + 0, + phi::errors::InvalidArgument( + "the input key should be exists. 
But received %d.", id_index)); + + if (input_data_type == paddle::framework::proto::VarType::BF16) { + memcpy(output + i * row_width, + table + id_index * row_width, + row_width * sizeof(T)); + } else { + auto blas = phi::funcs::GetBlas(dev_ctx_); + blas.VCOPY( + row_width, table + id_index * row_width, output + i * row_width); + } + } + } + } + + private: + const Context& dev_ctx_; + const DenseTensor& input_; + const SelectedRows& weight_; + DenseTensor* out_; + int64_t padding_idx_; +}; + +template +void SparseWeightEmbeddingKernel(const Context& ctx, + const DenseTensor& input, + const SelectedRows& weight, + int64_t padding_idx, + DenseTensor* out) { + EmbeddingCPUSparseFunctor functor( + ctx, input, weight, padding_idx, out); + + if (input.dtype() == phi::DataType::INT32) { + functor.template apply(); + } else if (input.dtype() == phi::DataType::INT64) { + functor.template apply(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "emebdding input only support int32 and int64")); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(sparse_weight_embedding, + CPU, + ALL_LAYOUT, + phi::SparseWeightEmbeddingKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/embedding_grad_kernel.h b/paddle/phi/kernels/embedding_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..40ffe6ec886c447a1d5f762cdbe01c95edb39764 --- /dev/null +++ b/paddle/phi/kernels/embedding_grad_kernel.h @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { + +template +void EmbeddingGradKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + DenseTensor* weight_grad); + +template +void EmbeddingSparseGradKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/embedding_kernel.h b/paddle/phi/kernels/embedding_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..cd7d675d6dc6cd7d71486437d9c56c4e73431af1 --- /dev/null +++ b/paddle/phi/kernels/embedding_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void EmbeddingKernel(const Context& ctx, + const DenseTensor& inputx, + const DenseTensor& weight, + int64_t padding_idx, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/funcs/batch_norm_utils.h b/paddle/phi/kernels/funcs/batch_norm_utils.h index 21ebae8487ffc3588034a8ea5feeab8ac1c47fa8..a7ed7d36eb1c41f688e875912c0a8648cb42cb03 100644 --- a/paddle/phi/kernels/funcs/batch_norm_utils.h +++ b/paddle/phi/kernels/funcs/batch_norm_utils.h @@ -36,8 +36,7 @@ inline void ResizeToChannelFirst(const DeviceContext& context, in_dims_vec[3] = input->dims()[2]; in_dims_vec[4] = input->dims()[3]; transformed_input->Resize(make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); - + context.template Alloc(transformed_input); } else if (dim == 2) { // input transformed_input->Resize(input->dims()); @@ -47,7 +46,7 @@ inline void ResizeToChannelFirst(const DeviceContext& context, in_dims_vec[2] = input->dims()[1]; in_dims_vec[3] = input->dims()[2]; transformed_input->Resize(make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); + context.template Alloc(transformed_input); } else if (dim == 1) { transformed_input->Resize(input->dims()); @@ -55,7 +54,7 @@ inline void ResizeToChannelFirst(const DeviceContext& context, in_dims_vec[1] = input->dims()[2]; in_dims_vec[2] = input->dims()[1]; transformed_input->Resize(make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); + context.template Alloc(transformed_input); } } @@ -74,7 +73,7 @@ inline void ResizeToChannelLast(const DeviceContext& context, in_dims_vec[3] = input->dims()[4]; in_dims_vec[4] = input->dims()[1]; transformed_input->Resize(make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); + context.template Alloc(transformed_input); } else if (dim == 2) { // input @@ -85,7 +84,7 @@ inline void ResizeToChannelLast(const DeviceContext& context, in_dims_vec[2] = input->dims()[3]; in_dims_vec[3] = input->dims()[1]; transformed_input->Resize(make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); + context.template Alloc(transformed_input); } else if (dim == 1) { transformed_input->Resize(input->dims()); @@ -93,7 +92,7 @@ inline void ResizeToChannelLast(const DeviceContext& context, in_dims_vec[1] = input->dims()[2]; in_dims_vec[2] = input->dims()[1]; transformed_input->Resize(make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); + context.template Alloc(transformed_input); } } diff --git a/paddle/phi/kernels/funcs/embedding_util.h b/paddle/phi/kernels/funcs/embedding_util.h new file mode 100644 index 0000000000000000000000000000000000000000..20c4ddca05460afbbd30491c9269ce935a3f611e --- /dev/null +++ b/paddle/phi/kernels/funcs/embedding_util.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +constexpr int64_t kNoPadding = -1; + +template +static std::vector CopyIdsToVector(const DenseTensor &ids) { + auto numel = ids.numel(); + const auto *src = ids.data(); + std::vector ret(numel); + if (std::is_same::value) { + std::memcpy(ret.data(), src, numel * sizeof(InT)); + } else { + for (decltype(numel) i = 0; i < numel; ++i) { + ret[i] = src[i]; + } + } + return ret; +} + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 2c9ee5ede010367697bb9477a536f807625fd02b..339c3536d7a7f476df0c2c46bf34ba48b73c07c3 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -359,8 +359,8 @@ void BatchNormGradRawKernel(const Context &ctx, } if (d_scale && d_bias) { - d_scale->mutable_data>(ctx.GetPlace()); - d_bias->mutable_data>(ctx.GetPlace()); + ctx.template Alloc>(d_scale); + ctx.template Alloc>(d_bias); } PADDLE_ENFORCE_EQ( @@ -569,8 +569,8 @@ void BatchNormGradRawKernel(const Context &ctx, /*activationDesc=*/nullptr, /*sizeInBytes=*/&workspace_size)); - workspace_ptr = workspace_tensor.mutable_data( - ctx.GetPlace(), transformed_x.type(), workspace_size); + workspace_tensor.Resize({static_cast(workspace_size)}); + workspace_ptr = ctx.template Alloc(&workspace_tensor); PADDLE_ENFORCE_GPU_SUCCESS( paddle::platform::dynload::cudnnBatchNormalizationBackwardEx( @@ -594,12 +594,9 @@ void BatchNormGradRawKernel(const Context &ctx, /*dBnScaleBiasDesc=*/bn_param_desc_, /*bnScaleData=*/scale.template data>(), /*bnBiasData=*/nullptr, - /*dBnScaleData=*/d_scale - ->template mutable_data>( - ctx.GetPlace()), - /*dBnBiasData=*/d_bias - ->template mutable_data>( - ctx.GetPlace()), + /*dBnScaleData=*/ctx.template Alloc>( + d_scale), + /*dBnBiasData=*/ctx.template Alloc>(d_bias), /*epsilon=*/epsilon, /*savedMean=*/saved_mean_data, /*savedInvVariance=*/saved_var_data, @@ -626,10 +623,8 @@ void BatchNormGradRawKernel(const Context &ctx, H * W * D, epsilon, transformed_d_x.template data(), - d_scale->template mutable_data>( - ctx.GetPlace()), - d_bias->template mutable_data>( - ctx.GetPlace())); + ctx.template Alloc>(d_scale), + ctx.template Alloc>(d_bias)); } else { BNBackward(), - d_scale->template mutable_data>( - ctx.GetPlace()), - d_bias->template mutable_data>( - ctx.GetPlace())); + ctx.template Alloc>(d_scale), + ctx.template Alloc>(d_bias)); } // TODO(wangran16): wait for MIOpen to improve the performance of BN @@ -682,10 +675,8 @@ void BatchNormGradRawKernel(const Context &ctx, ctx.template Alloc(&transformed_d_x), bn_param_desc_, scale.template data>(), - d_scale->template mutable_data>( - ctx.GetPlace()), - d_bias->template mutable_data>( - ctx.GetPlace()), + ctx.template Alloc>(d_scale), + ctx.template Alloc>(d_bias), epsilon, saved_mean_data, saved_var_data)); diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 49b550f51e60e1cf31658f0d50afebf929a54079..74a523f4ecf942422a1f6c5ca9f710dc0e9d4cbf 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -439,11 +439,11 @@ void BatchNormKernel(const Context &ctx, // Run training mode. // obtain running mean and running inv var, and there is no need // to initialize them. 
- mean_out->mutable_data>(ctx.GetPlace()); - variance_out->mutable_data>(ctx.GetPlace()); + ctx.template Alloc>(mean_out); + ctx.template Alloc>(variance_out); - saved_mean->mutable_data>(ctx.GetPlace()); - saved_variance->mutable_data>(ctx.GetPlace()); + ctx.template Alloc>(saved_mean); + ctx.template Alloc>(saved_variance); if ((N * H * W * D) == 1) { // Only 1 element in normalization dimension, @@ -497,10 +497,10 @@ void BatchNormKernel(const Context &ctx, /*xDesc=*/data_desc_, /*sizeInBytes=*/&reserve_space_size)); - reserve_space_ptr = reserve_space->mutable_data( - ctx.GetPlace(), transformed_x.type(), reserve_space_size); - workspace_ptr = workspace_tensor.mutable_data( - ctx.GetPlace(), transformed_x.type(), workspace_size); + reserve_space->Resize({static_cast(reserve_space_size)}); + reserve_space_ptr = ctx.template Alloc(reserve_space); + workspace_tensor.Resize({static_cast(workspace_size)}); + workspace_ptr = ctx.template Alloc(&workspace_tensor); PADDLE_ENFORCE_GPU_SUCCESS( paddle::platform::dynload::cudnnBatchNormalizationForwardTrainingEx( handle, @@ -518,15 +518,11 @@ void BatchNormKernel(const Context &ctx, scale.template data>(), bias.template data>(), this_factor, - mean_out->template mutable_data>( - ctx.GetPlace()), - variance_out->template mutable_data>( - ctx.GetPlace()), + ctx.template Alloc>(mean_out), + ctx.template Alloc>(variance_out), epsilon, - saved_mean->template mutable_data>( - ctx.GetPlace()), - saved_variance->template mutable_data>( - ctx.GetPlace()), + ctx.template Alloc>(saved_mean), + ctx.template Alloc>(saved_variance), nullptr, workspace_ptr, workspace_size, @@ -621,15 +617,11 @@ void BatchNormKernel(const Context &ctx, scale.template data>(), bias.template data>(), this_factor, - mean_out->template mutable_data>( - ctx.GetPlace()), - variance_out->template mutable_data>( - ctx.GetPlace()), + ctx.template Alloc>(mean_out), + ctx.template Alloc>(variance_out), epsilon, - saved_mean->template mutable_data>( - ctx.GetPlace()), - saved_variance->template mutable_data>( - ctx.GetPlace()))); + ctx.template Alloc>(saved_mean), + ctx.template Alloc>(saved_variance))); #endif } } diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..a970348760c18e2c67e9c7b366cdc2f5e18e3abd --- /dev/null +++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu @@ -0,0 +1,258 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
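Note: the batch-norm hunks above, and the convolution hunks later in this patch, all apply the same mechanical migration: the legacy call tensor->mutable_data<T>(place) becomes dev_ctx.template Alloc<T>(tensor), so the device context rather than the tensor performs the allocation. The following standalone sketch uses hypothetical Mock types (not Paddle classes) purely to illustrate the shape of the new call:

// Illustrative sketch only, not Paddle source; MockTensor/MockContext are stand-ins.
#include <cstdint>
#include <vector>

struct MockTensor {
  std::vector<char> buffer;  // stand-in for device memory
  int64_t numel = 0;
  void Resize(int64_t n) { numel = n; }
};

struct MockContext {
  // Mirrors dev_ctx.template Alloc<T>(tensor): the context decides how the
  // memory is obtained, replacing tensor->mutable_data<T>(place).
  template <typename T>
  T* Alloc(MockTensor* t) const {
    t->buffer.resize(static_cast<size_t>(t->numel) * sizeof(T));
    return reinterpret_cast<T*>(t->buffer.data());
  }
};

int main() {
  MockContext ctx;
  MockTensor workspace;
  workspace.Resize(128);                      // like workspace_tensor.Resize({workspace_size})
  float* ptr = ctx.Alloc<float>(&workspace);  // like ctx.template Alloc<T>(&workspace_tensor)
  ptr[0] = 0.0f;
  return 0;
}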
+ +#include "paddle/phi/kernels/embedding_grad_kernel.h" +#include "paddle/phi/kernels/funcs/embedding_util.h" + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +#include "paddle/fluid/framework/mixed_vector.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace phi { + +template +__global__ void InputTypeConvert(const InT* in_ids, + const int64_t K, + OutT* out_ids) { + for (int i = 0; i < K; i++) { + out_ids[i] = static_cast(in_ids[i]); + } +} + +template +__global__ void EmbeddingGrad(T* table, + const T* output, + const IdT* ids, + const int64_t N, + const int64_t K, + const int64_t D) { + int idx = threadIdx.x; + int idy = blockIdx.x + threadIdx.y * gridDim.x; + + while (idy < K) { + auto id = static_cast(ids[idy]); + const T* out = output + idy * D; + T* tab = table + id * D; +#ifdef PADDLE_WITH_CUDA + paddle::platform::VectorizedAtomicAddPerBlock(D, idx, blockDim.x, out, tab); +#else + for (int i = idx; i < D; i += blockDim.x) { + paddle::platform::CudaAtomicAdd(&tab[i], out[i]); + } +#endif + idy += blockDim.y * gridDim.x; + } +} + +template +struct EmbeddingGradCUDAFunctor { + EmbeddingGradCUDAFunctor(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + DenseTensor* weight_grad) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_grad_(out_grad), + padding_idx_(padding_idx), + weight_grad_(weight_grad) {} + + template + void apply() { + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. 
+ { + auto d_output_t = out_grad_; + auto d_table_t = weight_grad_; + + int N = weight_grad_->dims()[0]; + int D = weight_grad_->dims()[1]; + int K = input_.numel(); + + const T* d_output = d_output_t.template data(); + const auto* ids = input_.template data(); + T* d_table = dev_ctx_.template Alloc(d_table_t); + +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS( + hipMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream())); +#else + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream())); +#endif + + const int gridx = 2 * dev_ctx_.GetSMCount(); + dim3 threads(128, 8); + dim3 grids(gridx, 1); + EmbeddingGrad<<>>( + d_table, d_output, ids, N, K, D); + } + } + + private: + const phi::GPUContext& dev_ctx_; + const DenseTensor& input_; + const DenseTensor& weight_; + const DenseTensor& out_grad_; + int64_t padding_idx_; + DenseTensor* weight_grad_; +}; + +template +void EmbeddingGradKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + DenseTensor* weight_grad) { + EmbeddingGradCUDAFunctor functor( + ctx, input, weight, out_grad, padding_idx, weight_grad); + + if (input.dtype() == phi::DataType::INT32) { + functor.template apply(); + } else if (input.dtype() == phi::DataType::INT64) { + functor.template apply(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "emebdding input only support int32 and int64")); + } +} + +template +struct EmbeddingSparseGradCUDAFunctor { + EmbeddingSparseGradCUDAFunctor(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_grad_(out_grad), + padding_idx_(padding_idx), + weight_grad_(weight_grad) {} + + template + void apply() { + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + + const auto* ids_data = input_.template data(); + auto* d_table = weight_grad_; + auto* table = &weight_; + auto* d_output = &out_grad_; + int64_t ids_num = input_.numel(); + dim3 threads(128, 8); + dim3 grids(8, 1); + auto stream = dev_ctx_.stream(); + paddle::framework::Vector new_rows; + new_rows.resize(ids_num); + auto gpu_place = dev_ctx_.GetPlace(); + + paddle::framework::MixVector mixv_new_rows(&new_rows); + if (!std::is_same::value) { + InputTypeConvert<<>>( + ids_data, ids_num, mixv_new_rows.MutableData(gpu_place)); + } else { + paddle::memory::Copy(gpu_place, + mixv_new_rows.CUDAMutableData(gpu_place), + gpu_place, + ids_data, + ids_num * sizeof(int64_t), + stream); + } + + mixv_new_rows.CopyToCPU(); + d_table->set_rows(new_rows); + + auto* d_table_value = d_table->mutable_value(); + d_table_value->Resize({ids_num, table->dims()[1]}); + dev_ctx_.template Alloc(d_table_value); + + auto* d_table_data = d_table_value->template data(); + auto* d_output_data = d_output->template data(); + auto d_output_dims = d_output->dims(); + auto d_output_dims_2d = + phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1); + PADDLE_ENFORCE_EQ(d_table_value->dims(), + d_output_dims_2d, + phi::errors::InvalidArgument( + "ShapeError: The shape of lookup_table@Grad and " + "output@Grad should be same. 
" + "But received lookup_table@Grad's shape = [%s], " + "output@Grad's shape = [%s].", + d_table_value->dims(), + d_output_dims_2d)); + paddle::memory::Copy(gpu_place, + d_table_data, + gpu_place, + d_output_data, + d_output->numel() * sizeof(T), + stream); + } + + private: + const phi::GPUContext& dev_ctx_; + const DenseTensor& input_; + const DenseTensor& weight_; + const DenseTensor& out_grad_; + int64_t padding_idx_; + SelectedRows* weight_grad_; +}; + +template +void EmbeddingSparseGradKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) { + EmbeddingSparseGradCUDAFunctor functor( + ctx, input, weight, out_grad, padding_idx, weight_grad); + + if (input.dtype() == phi::DataType::INT32) { + functor.template apply(); + } else if (input.dtype() == phi::DataType::INT64) { + functor.template apply(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "emebdding input only support int32 and int64")); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(embedding_grad, + GPU, + ALL_LAYOUT, + phi::EmbeddingGradKernel, + float, + double, + phi::dtype::float16) {} + +PD_REGISTER_KERNEL(embedding_sparse_grad, + GPU, + ALL_LAYOUT, + phi::EmbeddingSparseGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/embedding_kernel.cu b/paddle/phi/kernels/gpu/embedding_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..7f3a31ba544d88534d8a606fba53e017a155023c --- /dev/null +++ b/paddle/phi/kernels/gpu/embedding_kernel.cu @@ -0,0 +1,126 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/embedding_kernel.h" +#include "paddle/phi/kernels/funcs/embedding_util.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template +__global__ void EmbeddingFW(T *output, + const T *table, + const IdT *ids, + const int64_t N, + const int64_t K, + const int64_t D, + const int64_t padding_idx) { + int idx = threadIdx.x; + int idy = blockIdx.x + threadIdx.y * gridDim.x; + + while (idy < K) { + auto id = static_cast(ids[idy]); + T *out = output + idy * D; + const T *tab = table + id * D; + for (int i = idx; i < D; i += blockDim.x) { + if (PaddingFlag) { + if (id == padding_idx) + out[i] = static_cast(0); + else + out[i] = tab[i]; + } else { + out[i] = tab[i]; + } + } + idy += blockDim.y * gridDim.x; + } +} + +template +struct EmbeddingCUDAFunctor { + EmbeddingCUDAFunctor(const Context &dev_ctx, + const DenseTensor &input, + const DenseTensor &weight, + int64_t padding_idx, + DenseTensor *out) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_(out), + padding_idx_(padding_idx) {} + + template + void apply() { + size_t N = weight_.dims()[0]; + size_t D = weight_.dims()[1]; + size_t K = input_.numel(); + + const int gridx = 2 * dev_ctx_.GetSMCount(); + dim3 threads(256, 4); + dim3 grids(gridx, 1); + + const T *table = weight_.template data(); + const IdT *ids = input_.template data(); + auto *output = dev_ctx_.template Alloc(out_); + auto stream = dev_ctx_.stream(); + + if (padding_idx_ == -1) { + EmbeddingFW<<>>( + output, table, ids, N, K, D, padding_idx_); + } else { + EmbeddingFW<<>>( + output, table, ids, N, K, D, padding_idx_); + } + } + + private: + const phi::GPUContext &dev_ctx_; + const DenseTensor &input_; + const DenseTensor &weight_; + DenseTensor *out_; + int64_t padding_idx_; +}; + +template +void EmbeddingKernel(const Context &ctx, + const DenseTensor &input, + const DenseTensor &weight, + int64_t padding_idx, + DenseTensor *out) { + EmbeddingCUDAFunctor functor( + ctx, input, weight, padding_idx, out); + + if (input.dtype() == phi::DataType::INT32) { + functor.template apply(); + } else if (input.dtype() == phi::DataType::INT64) { + functor.template apply(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "emebdding input only support int32 and int64")); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(embedding, + GPU, + ALL_LAYOUT, + phi::EmbeddingKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/reverse_kernel.cu.cc b/paddle/phi/kernels/gpu/reverse_kernel.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..f11eaa11bcdb192411c143e04169c1f1a24d7bf1 --- /dev/null +++ b/paddle/phi/kernels/gpu/reverse_kernel.cu.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
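Note: EmbeddingKernel above (and EmbeddingGradKernel earlier in the patch) resolves the runtime index dtype to a compile-time template argument: it inspects input.dtype() and then calls functor.template apply<int>() or functor.template apply<int64_t>(). Below is a minimal sketch of that dispatch pattern, with hypothetical names and a host-side loop standing in for the CUDA launch:

// Illustrative sketch only, not Paddle source: runtime dtype -> compile-time IdT.
#include <cstdint>
#include <iostream>
#include <stdexcept>

enum class DataType { INT32, INT64 };

struct LookupFunctor {
  const void* ids_data;
  int64_t n;

  // One instantiation per index type; in the real kernel this is where
  // the templated CUDA kernel would be launched.
  template <typename IdT>
  void apply() const {
    const IdT* ids = static_cast<const IdT*>(ids_data);
    int64_t sum = 0;
    for (int64_t i = 0; i < n; ++i) sum += static_cast<int64_t>(ids[i]);
    std::cout << "sum of ids: " << sum << "\n";
  }
};

void Dispatch(const LookupFunctor& functor, DataType dtype) {
  if (dtype == DataType::INT32) {
    functor.template apply<int>();
  } else if (dtype == DataType::INT64) {
    functor.template apply<int64_t>();
  } else {
    throw std::invalid_argument("embedding input only supports int32 and int64");
  }
}

int main() {
  int32_t ids[] = {1, 2, 3};
  Dispatch(LookupFunctor{ids, 3}, DataType::INT32);  // prints: sum of ids: 6
  return 0;
}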
+ +#include "paddle/phi/kernels/reverse_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/reverse_kernel_impl.h" + +PD_REGISTER_KERNEL(reverse, + GPU, + ALL_LAYOUT, + phi::ReverseKernel, + int, + uint8_t, + int64_t, + bool, + float, + double) {} diff --git a/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu index b4a6fe337c8d21e37beb0d6e5219e1a5edf1f9e8..9c5e77d5fd84661cdcc53dffc8f92a954df81041 100644 --- a/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu @@ -71,15 +71,15 @@ void ConvCudnnGradGradKernel( auto dW = filter_grad; auto dX = input_grad; if (ddO) { - ddO->mutable_data(ctx.GetPlace()); + ctx.template Alloc(ddO); phi::funcs::SetConstant set_zero; set_zero(ctx, ddO, static_cast(0)); } if (dW) { - dW->mutable_data(ctx.GetPlace()); + ctx.template Alloc(dW); } if (dX) { - dX->mutable_data(ctx.GetPlace()); + ctx.template Alloc(dX); } // const T* x = X->data(); @@ -131,7 +131,7 @@ void ConvCudnnGradGradKernel( } if (dX) { ResizeToChannelFirst(ctx, dX, &transformed_dX_channel); - transformed_dX_channel.mutable_data(ctx.GetPlace()); + ctx.template Alloc(&transformed_dX_channel); } } else { @@ -186,13 +186,13 @@ void ConvCudnnGradGradKernel( transformed_ddX.Resize(new_input_shape); transformed_dX.Resize(new_input_shape); - transformed_X.mutable_data(ctx.GetPlace()); + ctx.template Alloc(&transformed_X); if (ddX) { - transformed_ddX.mutable_data(ctx.GetPlace()); + ctx.template Alloc(&transformed_ddX); } if (dX) { - transformed_dX.mutable_data(ctx.GetPlace()); + ctx.template Alloc(&transformed_dX); } // pad for input diff --git a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu index 64148e902fdb2123aa3f81846999b5d90f356cd6..a99a1e5f9471ed8cf2513c4690630ff24b00c284 100644 --- a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu @@ -58,10 +58,10 @@ void ConvCudnnGradKernel(const Context& ctx, DenseTensor* input_grad, DenseTensor* filter_grad) { if (input_grad) { - input_grad->mutable_data(ctx.GetPlace()); + ctx.template Alloc(input_grad); } if (filter_grad) { - filter_grad->mutable_data(ctx.GetPlace()); + ctx.template Alloc(filter_grad); } std::vector dilations = dilations_t; @@ -204,12 +204,12 @@ void ConvCudnnGradKernel(const Context& ctx, } DDim new_input_shape(make_ddim(new_input_shape_vec)); transformed_input.Resize(new_input_shape); - transformed_input.mutable_data(ctx.GetPlace()); + ctx.template Alloc(&transformed_input); transformed_input_grad.Resize(new_input_shape); if (input_grad) { - transformed_input_grad.mutable_data(ctx.GetPlace()); + ctx.template Alloc(&transformed_input_grad); } // pad for input const int rank = transformed_input_channel.dims().size(); @@ -427,7 +427,7 @@ void ConvCudnnGradKernel(const Context& ctx, if (use_addto) { DenseTensor temp_tensor(transformed_input_grad.type()); temp_tensor.Resize(transformed_input_grad.dims()); - T* temp_tensor_data = temp_tensor.mutable_data(ctx.GetPlace()); + T* temp_tensor_data = ctx.template Alloc(&temp_tensor); workspace_handle.RunFunc( [&](void* cudnn_workspace_ptr) { PADDLE_ENFORCE_GPU_SUCCESS( @@ -513,7 +513,7 @@ void ConvCudnnGradKernel(const Context& ctx, axes[i] = i; } - transformed_input_grad_channel.mutable_data(ctx.GetPlace()); + ctx.template Alloc(&transformed_input_grad_channel); if (transformed_input_channel.dims().size() 
== 4) { paddle::operators::RemovePaddingSlice( ctx, diff --git a/paddle/phi/kernels/gpudnn/conv_kernel.cu b/paddle/phi/kernels/gpudnn/conv_kernel.cu index 931b6d68845e27297784603c2427178eae6b6f7d..c2970cc8cde75169602de5eec9f0e1424b71a701 100644 --- a/paddle/phi/kernels/gpudnn/conv_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_kernel.cu @@ -54,7 +54,7 @@ void ConvCudnnKernel(const Context& ctx, int workspace_size_MB, bool exhaustive_search_t, DenseTensor* output) { - output->mutable_data(ctx.GetPlace()); + ctx.template Alloc(output); std::vector paddings = paddings_t; std::vector dilations = dilations_t; @@ -170,7 +170,7 @@ void ConvCudnnKernel(const Context& ctx, } DDim new_input_shape(make_ddim(new_input_shape_vec)); transformed_input.Resize(new_input_shape); - transformed_input.mutable_data(ctx.GetPlace()); + ctx.template Alloc(&transformed_input); const int rank = transformed_input_channel.dims().size(); T pad_value(0.0); diff --git a/paddle/phi/kernels/impl/conv_grad_grad_kernel_impl.h b/paddle/phi/kernels/impl/conv_grad_grad_kernel_impl.h index fbcebf371a61bd3d652888b5eaad56185499726b..bc0ed44e17a3346db42f2f858caceabb9d5351b7 100644 --- a/paddle/phi/kernels/impl/conv_grad_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/conv_grad_grad_kernel_impl.h @@ -129,7 +129,7 @@ void ConvGradGradKernel(const Context& dev_ctx, DenseTensor col_matrix; if (is_expand) { col.Resize(col_shape); - col.mutable_data(dev_ctx.GetPlace()); + dev_ctx.template Alloc(&col); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); } @@ -143,7 +143,7 @@ void ConvGradGradKernel(const Context& dev_ctx, if (dX && ddW_in) { Tensor ddW; ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape); - dX->mutable_data(dev_ctx.GetPlace()); + dev_ctx.template Alloc(dX); DenseTensor transformed_dX(dX->type()); @@ -201,7 +201,7 @@ void ConvGradGradKernel(const Context& dev_ctx, // oH, oW) // dw convolution double grad: im2col(vol2col) + gemm if (dW && ddX) { - dW->mutable_data(dev_ctx.GetPlace()); + dev_ctx.template Alloc(dW); set_zero(dev_ctx, dW, static_cast(0)); DenseTensor dW_arr = *dW; dW_arr.Resize(filter_matrix_shape); @@ -244,7 +244,7 @@ void ConvGradGradKernel(const Context& dev_ctx, // w/ddw(Cout, Cin, kh, kw) // ddy convolution double grad: im2col(vol2col) + gemm if (ddY) { - ddY->mutable_data(dev_ctx.GetPlace()); + dev_ctx.template Alloc(ddY); DenseTensor transformed_ddY(ddY->type()); if (channel_last) { diff --git a/paddle/phi/kernels/impl/conv_grad_kernel_impl.h b/paddle/phi/kernels/impl/conv_grad_kernel_impl.h index f1971aca800b59171a2e741dbebce6d8adaf7899..2deebb996a057a84bd5343be76969ce3a12e1aa1 100644 --- a/paddle/phi/kernels/impl/conv_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/conv_grad_kernel_impl.h @@ -128,7 +128,7 @@ void ConvGradKernel(const Context& dev_ctx, DenseTensor col_matrix; if (is_expand) { col.Resize(col_shape); - col.mutable_data(dev_ctx.GetPlace()); + dev_ctx.template Alloc(&col); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); } @@ -137,7 +137,7 @@ void ConvGradKernel(const Context& dev_ctx, auto blas = phi::funcs::GetBlas(dev_ctx); if (input_grad) { - input_grad->mutable_data(dev_ctx.GetPlace()); + dev_ctx.template Alloc(input_grad); DenseTensor transformed_input_grad(input_grad->type()); if (channel_last) { ResizeToChannelFirst( @@ -203,7 +203,7 @@ void ConvGradKernel(const Context& dev_ctx, } if (filter_grad) { - filter_grad->mutable_data(dev_ctx.GetPlace()); + dev_ctx.template Alloc(filter_grad); Tensor filter_grad_ = *filter_grad; 
filter_grad_.Resize(filter_matrix_shape); set_zero(dev_ctx, filter_grad, static_cast(0)); diff --git a/paddle/phi/kernels/impl/conv_kernel_impl.h b/paddle/phi/kernels/impl/conv_kernel_impl.h index 1945468f02551b8e348687ae578c9f23a038b8ca..2ef2ed8af2809c453db4f5a8c20ed4e004bf64be 100644 --- a/paddle/phi/kernels/impl/conv_kernel_impl.h +++ b/paddle/phi/kernels/impl/conv_kernel_impl.h @@ -44,7 +44,7 @@ void ConvKernel(const Context& dev_ctx, // The filter will be reshaped in the calculations, // so here use an assignment operation, // that avoids modifying the variable in the Scope. - output->mutable_data(dev_ctx.GetPlace()); + dev_ctx.template Alloc(output); const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); @@ -115,7 +115,7 @@ void ConvKernel(const Context& dev_ctx, if (is_expand) { // col = context.AllocateTmpTensor(col_shape, dev_ctx); col.Resize(col_shape); - col.mutable_data(dev_ctx.GetPlace()); + dev_ctx.template Alloc(&col); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); } diff --git a/paddle/phi/kernels/impl/reverse_kernel_impl.h b/paddle/phi/kernels/impl/reverse_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..acdd46a086583f1b48b95977e047a18fa026d38b --- /dev/null +++ b/paddle/phi/kernels/impl/reverse_kernel_impl.h @@ -0,0 +1,91 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/phi/kernels/reverse_kernel.h" + +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template <typename Context, typename T, int Rank> +struct ReverseFunctor { + void operator()(const Context& dev_ctx, + const DenseTensor& in, + DenseTensor* out, + const std::vector<int>& axis) { + Eigen::DSizes<bool, Rank> reverse_axis; + for (int i = 0; i < Rank; ++i) { + reverse_axis[i] = false; + } + for (int a : axis) { + if (a >= 0) { + reverse_axis[a] = true; + } else { + reverse_axis[Rank + a] = true; + } + } + + auto in_eigen = EigenTensor<T, Rank>::From(in); + auto out_eigen = EigenTensor<T, Rank>::From(*out); + auto& dev = *dev_ctx.eigen_device(); + + funcs::EigenReverse<std::decay_t<decltype(dev)>, T, Rank>::Eval( + dev, out_eigen, in_eigen, reverse_axis); + } +}; + +template <typename T, typename Context> +void ReverseKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector<int>& axis, + DenseTensor* out) { + dev_ctx.template Alloc<T>(out); + int rank = x.dims().size(); + + switch (rank) { + case 1: + ReverseFunctor<Context, T, 1> functor1; + functor1(dev_ctx, x, out, axis); + break; + case 2: + ReverseFunctor<Context, T, 2> functor2; + functor2(dev_ctx, x, out, axis); + break; + case 3: + ReverseFunctor<Context, T, 3> functor3; + functor3(dev_ctx, x, out, axis); + break; + case 4: + ReverseFunctor<Context, T, 4> functor4; + functor4(dev_ctx, x, out, axis); + break; + case 5: + ReverseFunctor<Context, T, 5> functor5; + functor5(dev_ctx, x, out, axis); + break; + case 6: + ReverseFunctor<Context, T, 6> functor6; + functor6(dev_ctx, x, out, axis); + break; + default: + PADDLE_THROW(phi::errors::OutOfRange( + "The reverse operator does not support input tensors " + "whose ranks are greater than 6.")); + } +} + +} // namespace phi
diff --git a/paddle/phi/kernels/reverse_kernel.cc b/paddle/phi/kernels/reverse_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..c6c2781a07bf6af06707c6fe4bcc884b9454c8c4 --- /dev/null +++ b/paddle/phi/kernels/reverse_kernel.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
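Note: in ReverseFunctor above, negative axes are normalized against the tensor rank (an axis a < 0 flips dimension Rank + a) before the boolean mask is handed to Eigen, and ReverseKernel then selects the Rank template argument through a switch on the runtime rank. Below is a small sketch of the mask construction only, with hypothetical names:

// Illustrative sketch only, not Paddle source: per-dimension reverse mask
// with negative-axis normalization, as done in ReverseFunctor.
#include <array>
#include <iostream>
#include <vector>

template <int Rank>
std::array<bool, Rank> BuildReverseMask(const std::vector<int>& axis) {
  std::array<bool, Rank> mask;
  mask.fill(false);
  for (int a : axis) {
    // Negative axes count from the end, e.g. -1 is the last dimension.
    mask[a >= 0 ? a : Rank + a] = true;
  }
  return mask;
}

int main() {
  // Reverse dimensions 0 and 2 of a rank-3 tensor, requested as {0, -1}.
  auto mask = BuildReverseMask<3>({0, -1});
  for (bool b : mask) std::cout << b << ' ';  // prints: 1 0 1
  std::cout << '\n';
  return 0;
}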
+ +#include "paddle/phi/kernels/reverse_kernel.h" + +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" + +namespace phi { + +template +void ReverseArrayKernel(const Context& dev_ctx, + const std::vector& x, + const std::vector& axis, + std::vector out) { + PADDLE_ENFORCE_EQ( + x.size(), + out.size(), + phi::errors::InvalidArgument("The input size(%d) and output size(%d) of " + "ReverseArrayKernel is different.", + x.size(), + out.size())); + for (size_t offset = 0; offset < x.size(); ++offset) { + auto* x_tensor = x.at(offset); + PADDLE_ENFORCE_GT( + x_tensor->memory_size(), + 0, + phi::errors::PreconditionNotMet( + "The input LoDTensorArray X[%d] holds no memory.", offset)); + auto out_offset = x.size() - offset - 1; + auto* out_tensor = out.at(out_offset); + + out_tensor->set_lod(x_tensor->lod()); + phi::Copy( + dev_ctx, *x_tensor, dev_ctx.GetPlace(), false, out_tensor); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(reverse_array, + CPU, + ALL_LAYOUT, + phi::ReverseArrayKernel, + int, + uint8_t, + int64_t, + bool, + float, + double) {} +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(reverse_array, + GPU, + ALL_LAYOUT, + phi::ReverseArrayKernel, + int, + uint8_t, + int64_t, + bool, + float, + double) {} +#endif diff --git a/paddle/phi/kernels/reverse_kernel.h b/paddle/phi/kernels/reverse_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..2b81f4018c25d896745637f032c25dbe5551ef26 --- /dev/null +++ b/paddle/phi/kernels/reverse_kernel.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ReverseKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axis, + DenseTensor* out); + +template +void ReverseArrayKernel(const Context& dev_ctx, + const std::vector& x, + const std::vector& axis, + std::vector out); + +} // namespace phi diff --git a/paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h b/paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..772268c2cc3889db6c328fa99425dc6996320050 --- /dev/null +++ b/paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { + +template +void SparseWeightEmbeddingGradKernel(const Context& ctx, + const DenseTensor& input, + const SelectedRows& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + DenseTensor* weight_grad); + +template +void SparseWeightEmbeddingSparseGradKernel(const Context& ctx, + const DenseTensor& input, + const SelectedRows& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/sparse_weight_embedding_kernel.h b/paddle/phi/kernels/sparse_weight_embedding_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..c7392b691aa0fa4f0bc28e35cc29bb6aa902c34f --- /dev/null +++ b/paddle/phi/kernels/sparse_weight_embedding_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { + +template +void SparseWeightEmbeddingKernel(const Context& ctx, + const DenseTensor& inputx, + const SelectedRows& weight, + int64_t padding_idx, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/conv2d_sig.cc b/paddle/phi/ops/compat/conv2d_sig.cc index a755fdb19ec4b86d4b5265c1d6bce5eecdb9b5b3..67b99f1dd619c09b1a497b547a82287650b30211 100644 --- a/paddle/phi/ops/compat/conv2d_sig.cc +++ b/paddle/phi/ops/compat/conv2d_sig.cc @@ -17,18 +17,31 @@ namespace phi { KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("conv2d", - {"Input", "Filter"}, - {"strides", - "paddings", - "padding_algorithm", - "groups", - "dilations", - "data_format", - "use_addto", - "workspace_size_MB", - "exhaustive_search"}, - {"Output"}); + if (!ctx.HasAttr("use_addto") || !ctx.HasAttr("workspace_size_MB") || + !ctx.HasAttr("exhaustive_search")) { + return KernelSignature("conv2d_infer", + {"Input", "Filter"}, + {"strides", + "paddings", + "padding_algorithm", + "groups", + "dilations", + "data_format"}, + {"Output"}); + } else { + return KernelSignature("conv2d", + {"Input", "Filter"}, + {"strides", + "paddings", + "padding_algorithm", + "groups", + "dilations", + "data_format", + "use_addto", + "workspace_size_MB", + "exhaustive_search"}, + {"Output"}); + } } KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { diff --git a/paddle/phi/ops/compat/embedding_sig.cc b/paddle/phi/ops/compat/embedding_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..b79a381dcecc7943d0e82dbf122ece783cc33791 --- /dev/null +++ b/paddle/phi/ops/compat/embedding_sig.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorInput("W")) { + return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); + } else { + return KernelSignature( + "sparse_weight_embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); + } +} + +KernelSignature EmbeddingGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorInput("W")) { + if ((paddle::any_cast(ctx.Attr("is_sparse"))) == true) { + return KernelSignature("embedding_sparse_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } else { + return KernelSignature("embedding_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } + } else { + if ((paddle::any_cast(ctx.Attr("is_sparse"))) == true) { + return KernelSignature("sparse_weight_embedding_sparse_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } else { + return KernelSignature("sparse_weight_embedding_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } + } +} + +} // namespace phi + +PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2, embedding); +PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2_grad, embedding_grad); + +PD_REGISTER_ARG_MAPPING_FN(lookup_table_v2, phi::EmbeddingOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(lookup_table_v2_grad, + phi::EmbeddingGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/reverse_sig.cc b/paddle/phi/ops/compat/reverse_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b70893fa7877e951cdaba5584e0ff6b31987a9d --- /dev/null +++ b/paddle/phi/ops/compat/reverse_sig.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
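Note: EmbeddingGradOpArgumentMapping in embedding_sig.cc above selects one of four phi kernels from two pieces of runtime information, whether W is a dense tensor and whether the is_sparse attribute is set. The same selection table written as a plain helper, purely illustrative and with a hypothetical function name:

// Illustrative sketch only, not Paddle source: the kernel-name table behind
// EmbeddingGradOpArgumentMapping.
#include <iostream>
#include <string>

std::string SelectEmbeddingGradKernel(bool dense_weight, bool is_sparse) {
  if (dense_weight) {
    return is_sparse ? "embedding_sparse_grad" : "embedding_grad";
  }
  return is_sparse ? "sparse_weight_embedding_sparse_grad"
                   : "sparse_weight_embedding_grad";
}

int main() {
  std::cout << SelectEmbeddingGradKernel(true, false) << "\n";  // embedding_grad
  std::cout << SelectEmbeddingGradKernel(false, true) << "\n";  // sparse_weight_embedding_sparse_grad
  return 0;
}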
+ +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature ReverseOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorVectorInput("X")) { + return KernelSignature("reverse_array", {"X"}, {"axis"}, {"Out"}); + } else { + return KernelSignature("reverse", {"X"}, {"axis"}, {"Out"}); + } +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(reverse, phi::ReverseOpArgumentMapping); diff --git a/paddle/scripts/infrt_build.sh b/paddle/scripts/infrt_build.sh index 1b259023f94df7279066533bb6c182a644b4e9c2..37e19b49f1cd03dc08dadd977358118a3190289c 100755 --- a/paddle/scripts/infrt_build.sh +++ b/paddle/scripts/infrt_build.sh @@ -114,6 +114,7 @@ function create_fake_models() { python3 -m pip install *whl cd ${PADDLE_ROOT}/build python3 ${PADDLE_ROOT}/tools/infrt/fake_models/multi_fc.py + python3 ${PADDLE_ROOT}/paddle/infrt/tests/model/linear.py } function test_infrt() { diff --git a/python/paddle/fluid/dygraph/tracer.py b/python/paddle/fluid/dygraph/tracer.py index 1a8cc77e4def59ca6bd1b01b903c4a96a4238b15..d1efe0afeaad09f7c032e4ed692f4eed330d08b5 100644 --- a/python/paddle/fluid/dygraph/tracer.py +++ b/python/paddle/fluid/dygraph/tracer.py @@ -269,7 +269,6 @@ class Tracer(core.Tracer): if framework._in_eager_mode(): # inputs : {"sum": [tensor], ...} # outputs : {"sum": [tensor], ...} - if type in final_state_name_mapping.keys(): final_state_type = final_state_name_mapping[type][ "final_op_name"] diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d247fdb53ea6dbbbe8fdc4a03369c4c927ef664c..221aba9a882e5d92a463fb6cc027a6915e6ca82b 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -8730,8 +8730,8 @@ def scatter_nd_add(ref, index, updates, name=None): """ if in_dygraph_mode(): - if _in_eager_mode(): - return _C_ops.final_state_scatter_nd_add(ref, index, updates) + #if _in_eager_mode(): + #return _C_ops.final_state_scatter_nd_add(ref, index, updates) op = getattr(_C_ops, 'scatter_nd_add') return op(ref, index, updates) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 73f73ad16399d8a502546c167a796f927d702635..328301d4be8e9410375c7ec937e78aa3e0e9bb06 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -698,7 +698,10 @@ class OpTest(unittest.TestCase): + str(np_dyg) + "\n" + "But Got" + str(np_api) + " in class " + self.__class__.__name__) - def _calc_python_api_output(self, place): + def _calc_python_api_output(self, place, egr_inps=None, egr_oups=None): + """ set egr_inps and egr_oups = None if you want to create it by yourself. + """ + def prepare_python_api_arguments(api, op_proto_ins, op_proto_attrs, kernel_sig): """ map from `op proto inputs and attrs` to `api input list and api attrs dict` @@ -753,10 +756,15 @@ class OpTest(unittest.TestCase): def construct_output_dict_by_kernel_sig(ret_tuple, output_sig): if not isinstance(ret_tuple, (tuple, list)): ret_tuple = [ret_tuple] - assert len(output_sig) == len( - ret_tuple), "expect %d outputs, but get %d outputs" % ( - len(output_sig), len(ret_tuple)) - return {a: b for a, b in zip(output_sig, ret_tuple)} + if len(output_sig) == len(ret_tuple): + # [assumption]: we assume {"Out": [Tensor]} + return {a: [b] for a, b in zip(output_sig, ret_tuple)} + else: + # [assumption]: return multi-Tensor in a single output. 
such as paddle.split() + assert len( + output_sig + ) == 1, "Don't support multi-output with multi-tensor output." + return {output_sig[0]: ret_tuple} def assumption_assert_and_transform(args, inp_num): """ @@ -775,6 +783,18 @@ class OpTest(unittest.TestCase): ] + args[inp_num:] return args + def _get_kernel_signature(eager_tensor_inputs, eager_tensor_outputs, + attrs_outputs): + try: + kernel_sig = _dygraph_tracer()._get_kernel_signature( + self.op_type, eager_tensor_inputs, eager_tensor_outputs, + attrs_outputs) + except RuntimeError as re: + """ we think the kernel_sig is missing. + """ + kernel_sig = None + return kernel_sig + def cal_python_api(python_api, args, kernel_sig): inputs_sig, attrs_sig, outputs_sig = kernel_sig args = assumption_assert_and_transform(args, len(inputs_sig)) @@ -785,10 +805,10 @@ class OpTest(unittest.TestCase): block = fluid.default_main_program().global_block() op_proto = OpProtoHolder.instance().get_op_proto(self.op_type) # prepare input variable - eager_tensor_inputs = self.append_input_output_for_dygraph( + eager_tensor_inputs = egr_inps if egr_inps else self.append_input_output_for_dygraph( op_proto, self.inputs, True, False, block) # prepare output variable - eager_tensor_outputs = self.append_input_output_for_dygraph( + eager_tensor_outputs = egr_oups if egr_oups else self.append_input_output_for_dygraph( op_proto, self.outputs, False, False, block) # prepare attrbutes @@ -798,13 +818,13 @@ class OpTest(unittest.TestCase): if self.attrs[attrs_name] is not None: attrs_outputs[attrs_name] = self.attrs[attrs_name] - kernel_sig = _dygraph_tracer()._get_kernel_signature( - self.op_type, eager_tensor_inputs, eager_tensor_outputs, - attrs_outputs) - + kernel_sig = _get_kernel_signature( + eager_tensor_inputs, eager_tensor_outputs, attrs_outputs) + if not kernel_sig: + return None assert hasattr( self, "python_api" - ), "Please set the `self.python_api` if you want to compare python api output." + ), "Detect there is KernelSignature for `%s` op, please set the `self.python_api` if you set check_eager = True" % self.op_type args = prepare_python_api_arguments( self.python_api, eager_tensor_inputs, attrs_outputs, kernel_sig) """ we directly return the cal_python_api value because the value is already tensor. @@ -1285,14 +1305,13 @@ class OpTest(unittest.TestCase): place, no_check_set=no_check_set) if check_eager: + # we only check end2end api when check_eager=True with _test_eager_guard(): - eager_dygraph_outs = self._calc_dygraph_output( - place, no_check_set=no_check_set) - # we only check end2end api when check_eager=True - if hasattr(self, "python_api"): - api_outs = self._calc_python_api_output(place) - self._check_api_outs_by_dygraph_outs(api_outs, dygraph_outs, - place) + eager_dygraph_outs = self._calc_python_api_output(place) + if eager_dygraph_outs is None: + # missing KernelSignature, fall back to eager middle output. 
+ eager_dygraph_outs = self._calc_dygraph_output( + place, no_check_set=no_check_set) outs, fetch_list = self._calc_output(place, no_check_set=no_check_set) @@ -1827,7 +1846,7 @@ class OpTest(unittest.TestCase): if check_dygraph: dygraph_grad = self._get_dygraph_grad( inputs_to_check, place, output_names, user_defined_grad_outputs, - no_grad_set) + no_grad_set, False) fp32_grads = [] for grad in dygraph_grad: if grad.dtype == np.uint16: @@ -1843,7 +1862,7 @@ class OpTest(unittest.TestCase): with _test_eager_guard(): eager_dygraph_grad = self._get_dygraph_grad( inputs_to_check, place, output_names, - user_defined_grad_outputs, no_grad_set) + user_defined_grad_outputs, no_grad_set, check_eager) fp32_grads = [] for grad in eager_dygraph_grad: if grad.dtype == np.uint16: @@ -1869,7 +1888,8 @@ class OpTest(unittest.TestCase): place, output_names, user_defined_grad_outputs=None, - no_grad_set=None): + no_grad_set=None, + check_eager=False): with fluid.dygraph.base.guard(place=place): block = fluid.default_main_program().global_block() @@ -1890,11 +1910,16 @@ class OpTest(unittest.TestCase): if self.attrs[attrs_name] is not None: attrs_outputs[attrs_name] = self.attrs[attrs_name] - block.append_op( - type=self.op_type, - inputs=inputs, - outputs=outputs, - attrs=attrs_outputs if hasattr(self, "attrs") else None) + if check_eager: + outputs = self._calc_python_api_output(place, inputs, outputs) + + # if outputs is None, kernel sig is empty or other error is happens. + if not check_eager or outputs is None: + block.append_op( + type=self.op_type, + inputs=inputs, + outputs=outputs, + attrs=attrs_outputs if hasattr(self, "attrs") else None) if self.dtype == np.uint16: cast_inputs = self._find_var_in_dygraph(outputs, diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index cc8fa345e7b67e1a560562e8d9a7797c15e42c6a..6bb6d56497e174e27e7fc5a1d6fb8278a482ef41 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -1048,7 +1048,7 @@ class TestAbs(TestActivation): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=False) class TestCeil(TestActivation): diff --git a/python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py b/python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py index eada96665f0b8d8c9b53881de552abe724b9f828..768a8c0635720a9b223e524716a3206372128c05 100644 --- a/python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py +++ b/python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py @@ -29,6 +29,7 @@ from paddle.fluid import Program, program_guard, core paddle.enable_static() +#cholesky_solve implement 1 def cholesky_solution(X, B, upper=True): if upper: A = np.triu(X) @@ -43,6 +44,7 @@ def cholesky_solution(X, B, upper=True): L, B, lower=True)) +#cholesky_solve implement 2 def scipy_cholesky_solution(X, B, upper=True): if upper: umat = np.triu(X) @@ -54,27 +56,29 @@ def scipy_cholesky_solution(X, B, upper=True): return scipy.linalg.cho_solve(K, B) -def boardcast_shape(matA, matB): +#broadcast function used by cholesky_solve +def broadcast_shape(matA, matB): shapeA = matA.shape shapeB = matB.shape - Boardshape = [] + Broadshape = [] for idx in range(len(shapeA) - 2): if shapeA[idx] == shapeB[idx]: - Boardshape.append(shapeA[idx]) + Broadshape.append(shapeA[idx]) continue elif shapeA[idx] == 1 or shapeB[idx] == 
1: - Boardshape.append(max(shapeA[idx], shapeB[idx])) + Broadshape.append(max(shapeA[idx], shapeB[idx])) else: raise Exception( - 'shapeA and shapeB should be boardcasted, but got {} and {}'. + 'shapeA and shapeB should be broadcasted, but got {} and {}'. format(shapeA, shapeB)) - bsA = Boardshape + list(shapeA[-2:]) - bsB = Boardshape + list(shapeB[-2:]) + bsA = Broadshape + list(shapeA[-2:]) + bsB = Broadshape + list(shapeB[-2:]) return np.broadcast_to(matA, bsA), np.broadcast_to(matB, bsB) +#cholesky_solve implement in batch def scipy_cholesky_solution_batch(bumat, bB, upper=True): - bumat, bB = boardcast_shape(bumat, bB) + bumat, bB = broadcast_shape(bumat, bB) ushape = bumat.shape bshape = bB.shape bumat = bumat.reshape((-1, ushape[-2], ushape[-1])) @@ -90,18 +94,21 @@ def scipy_cholesky_solution_batch(bumat, bB, upper=True): return np.array(bx).reshape(bshape) -# 2D + 2D , , upper=False +# test condition: shape: 2D + 2D , upper=False +# based on OpTest class class TestCholeskySolveOp(OpTest): """ case 1 """ + #test condition set def config(self): self.y_shape = [15, 15] self.x_shape = [15, 5] self.upper = False - self.dtype = np.float64 + self.dtype = np.float64 #Here cholesky_solve Op only supports float64/float32 type, please check others if Op supports more types. + #get scipy result def set_output(self): umat = self.inputs['Y'] self.output = scipy_cholesky_solution_batch( @@ -125,14 +132,16 @@ class TestCholeskySolveOp(OpTest): self.set_output() self.outputs = {'Out': self.output} + #check Op forward result def test_check_output(self): self.check_output(check_eager=True) + #check Op grad def test_check_grad_normal(self): self.check_grad(['Y'], 'Out', max_relative_error=0.01, check_eager=True) -# 3D(broadcast) + 3D, upper=True +# test condition: 3D(broadcast) + 3D, upper=True class TestCholeskySolveOp3(TestCholeskySolveOp): """ case 3 @@ -145,11 +154,11 @@ class TestCholeskySolveOp3(TestCholeskySolveOp): self.dtype = np.float64 +#API function test class TestCholeskySolveAPI(unittest.TestCase): def setUp(self): np.random.seed(2021) self.place = [paddle.CPUPlace()] - # self.place = [paddle.CUDAPlace(0)] self.dtype = "float64" self.upper = True if core.is_compiled_with_cuda(): @@ -178,10 +187,12 @@ class TestCholeskySolveAPI(unittest.TestCase): fetch_list=[z]) self.assertTrue(np.allclose(fetches[0], z_np)) + #test in static mode def test_static(self): for place in self.place: self.check_static_result(place=place) + #test in dynamic mode def test_dygraph(self): def run(place): paddle.disable_static(place) @@ -200,7 +211,8 @@ class TestCholeskySolveAPI(unittest.TestCase): for idx, place in enumerate(self.place): run(place) - def test_boardcast(self): + #test input with broadcast + def test_broadcast(self): def run(place): paddle.disable_static() x_np = np.random.random([1, 30, 2]).astype(self.dtype) @@ -219,6 +231,7 @@ class TestCholeskySolveAPI(unittest.TestCase): run(place) +#test condition out of bounds class TestCholeskySolveOpError(unittest.TestCase): def test_errors(self): paddle.enable_static() diff --git a/python/paddle/fluid/tests/unittests/test_diag_v2.py b/python/paddle/fluid/tests/unittests/test_diag_v2.py index 74e73ca5cdf5a44828b41b7da68643264e6f1e89..4047ccb8782c877364cf2375a642f48210efeac3 100644 --- a/python/paddle/fluid/tests/unittests/test_diag_v2.py +++ b/python/paddle/fluid/tests/unittests/test_diag_v2.py @@ -43,11 +43,11 @@ class TestDiagV2Op(OpTest): def test_check_output(self): paddle.enable_static() - self.check_output(check_eager=True) + 
self.check_output(check_eager=False) def test_check_grad(self): paddle.enable_static() - self.check_grad(['X'], 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=False) def init_config(self): pass diff --git a/python/paddle/fluid/tests/unittests/test_diagonal_op.py b/python/paddle/fluid/tests/unittests/test_diagonal_op.py index b4854aea52a70bd5307193377c85ab229d949e1a..7db5fcb9625a6f89bd7a13512ff44c7ced6474bf 100644 --- a/python/paddle/fluid/tests/unittests/test_diagonal_op.py +++ b/python/paddle/fluid/tests/unittests/test_diagonal_op.py @@ -30,6 +30,7 @@ paddle.enable_static() class TestDiagonalOp(OpTest): def setUp(self): self.op_type = "diagonal" + self.python_api = paddle.diagonal self.init_config() self.outputs = {'Out': self.target} diff --git a/python/paddle/fluid/tests/unittests/test_digamma_op.py b/python/paddle/fluid/tests/unittests/test_digamma_op.py index 3cb31b888f431741bab6098b3cb85c1d3b327e57..4897becf61144fadddb9c8b0efc9dac5f2b4bbf5 100644 --- a/python/paddle/fluid/tests/unittests/test_digamma_op.py +++ b/python/paddle/fluid/tests/unittests/test_digamma_op.py @@ -29,6 +29,7 @@ class TestDigammaOp(OpTest): paddle.enable_static() self.op_type = 'digamma' + self.python_api = paddle.digamma self.init_dtype_type() shape = (5, 32) data = np.random.random(shape).astype(self.dtype) + 1 diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index 909e00d1a316a283476c6535ad04d23d5be08ced..4ddfe9d1559de3cd076bc3d03a904dc9c013d44e 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -41,7 +41,8 @@ class TestElementwiseAddOp(OpTest): self.outputs = {'Out': self.out} def check_eager(self): - return (self.use_mkldnn == False and self.axis == -1) + return False + #return (self.use_mkldnn == False and self.axis == -1) def test_check_output(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode diff --git a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py index a7331a353afe822ddae09e2e4034e5e6eeedfc1f..ac2d980f7fd383e274558cbcd2be4a3db3d54747 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py @@ -34,10 +34,10 @@ class TestGatherNdOpWithEmptyIndex(OpTest): } def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=False) class TestGatherNdOpWithIndex1(OpTest): @@ -49,10 +49,10 @@ class TestGatherNdOpWithIndex1(OpTest): self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=False) class TestGatherNdOpWithLowIndex(OpTest): @@ -69,10 +69,10 @@ class TestGatherNdOpWithLowIndex(OpTest): self.outputs = {'Out': xnp[tuple(index.T)]} #[[14, 25, 1], [76, 22, 3]] def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=False) class TestGatherNdOpIndex1(OpTest): @@ -89,10 +89,10 @@ class 
TestGatherNdOpIndex1(OpTest): self.outputs = {'Out': xnp[tuple(index.T)]} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=False) class TestGatherNdOpWithSameIndexAsX(OpTest): @@ -108,10 +108,10 @@ class TestGatherNdOpWithSameIndexAsX(OpTest): self.outputs = {'Out': xnp[tuple(index.T)]} #[25, 22] def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=False) class TestGatherNdOpWithHighRankSame(OpTest): @@ -128,10 +128,10 @@ class TestGatherNdOpWithHighRankSame(OpTest): self.outputs = {'Out': xnp[tuple(index.T)]} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=False) class TestGatherNdOpWithHighRankDiff(OpTest): @@ -149,10 +149,10 @@ class TestGatherNdOpWithHighRankDiff(OpTest): self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=False) #Test Python API diff --git a/python/paddle/fluid/tests/unittests/test_imperative_partitial_backward.py b/python/paddle/fluid/tests/unittests/test_imperative_partitial_backward.py index 5e3d3c811882ff4733e77dccd491fe462ab835cf..cd31b13083de4cdc2c3dddd344b4efa57a26cf68 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_partitial_backward.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_partitial_backward.py @@ -17,10 +17,11 @@ from __future__ import print_function import unittest import paddle.fluid as fluid import numpy as np +from paddle.fluid.framework import _test_eager_guard class TestImperativePartitialBackward(unittest.TestCase): - def test_partitial_backward(self): + def func_partitial_backward(self): with fluid.dygraph.guard(): x = np.random.randn(2, 4, 5).astype("float32") x = fluid.dygraph.to_variable(x) @@ -49,6 +50,11 @@ class TestImperativePartitialBackward(unittest.TestCase): linear1.clear_gradients() linear2.clear_gradients() + def test_partitial_backward(self): + with _test_eager_guard(): + self.func_partitial_backward() + self.func_partitial_backward() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_index_sample_op.py b/python/paddle/fluid/tests/unittests/test_index_sample_op.py index 4da03c9643fa97e4d1750e257998a658e079f0f5..e2ccb153f406315e4965df222e0689eee646aacb 100644 --- a/python/paddle/fluid/tests/unittests/test_index_sample_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_sample_op.py @@ -40,10 +40,10 @@ class TestIndexSampleOp(OpTest): self.outputs = {'Out': out} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=False) def config(self): """ diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py index 
65d0e289f81329561eaec73d10aa639689f0e1d3..492f300e3b8481cb2d39266c359b916ada346981 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py @@ -105,14 +105,14 @@ class TestMatMulV2Op(OpTest): self.outputs = {'Out': result} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): if core.is_compiled_with_rocm(): self.check_grad( - ['X', 'Y'], 'Out', max_relative_error=1e-2, check_eager=True) + ['X', 'Y'], 'Out', max_relative_error=1e-2, check_eager=False) else: - self.check_grad(['X', 'Y'], 'Out', check_eager=True) + self.check_grad(['X', 'Y'], 'Out', check_eager=False) class TestMatMulOp2(TestMatMulV2Op): @@ -346,7 +346,7 @@ def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_output_with_place( - place, atol=atol, check_eager=True) + place, atol=atol, check_eager=False) def test_check_grad(self): place = core.CUDAPlace(0) @@ -355,7 +355,7 @@ def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0): place, ['X', 'Y'], 'Out', max_relative_error=max_relative_error, - check_eager=True) + check_eager=False) cls_name = "{0}_{1}".format(parent.__name__, "Fp16") TestMatMulOpFp16Case.__name__ = cls_name @@ -534,7 +534,7 @@ class TestComplexMatMulOp(OpTest): self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out) def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad_normal(self): self.check_grad( @@ -542,7 +542,7 @@ class TestComplexMatMulOp(OpTest): 'Out', user_defined_grads=[self.grad_x, self.grad_y], user_defined_grad_outputs=[self.grad_out], - check_eager=True) + check_eager=False) def test_check_grad_ingore_x(self): self.check_grad( @@ -551,7 +551,7 @@ class TestComplexMatMulOp(OpTest): no_grad_set=set("X"), user_defined_grads=[self.grad_y], user_defined_grad_outputs=[self.grad_out], - check_eager=True) + check_eager=False) def test_check_grad_ingore_y(self): self.check_grad( @@ -560,7 +560,7 @@ class TestComplexMatMulOp(OpTest): no_grad_set=set('Y'), user_defined_grads=[self.grad_x], user_defined_grad_outputs=[self.grad_out], - check_eager=True) + check_eager=False) class TestComplexMatMulOpBroadcast(OpTest): @@ -598,7 +598,7 @@ class TestComplexMatMulOpBroadcast(OpTest): axis=0) def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad_normal(self): self.check_grad( @@ -606,7 +606,7 @@ class TestComplexMatMulOpBroadcast(OpTest): 'Out', user_defined_grads=[self.grad_x, self.grad_y], user_defined_grad_outputs=[self.grad_out], - check_eager=True) + check_eager=False) def test_check_grad_ingore_x(self): self.check_grad( @@ -615,7 +615,7 @@ class TestComplexMatMulOpBroadcast(OpTest): no_grad_set=set("X"), user_defined_grads=[self.grad_y], user_defined_grad_outputs=[self.grad_out], - check_eager=True) + check_eager=False) def test_check_grad_ingore_y(self): self.check_grad( @@ -624,7 +624,7 @@ class TestComplexMatMulOpBroadcast(OpTest): no_grad_set=set('Y'), user_defined_grads=[self.grad_x], user_defined_grad_outputs=[self.grad_out], - check_eager=True) + check_eager=False) class TestMatMulTypePromotion(TestComplexMatMulOp): diff --git a/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py b/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py index 
ddbee33c35bb1d5b6d1c4ea2b5dec527f4093ce5..d7a27bbddebbaeb1483a295a0e1c4f4d4b8d3b79 100644 --- a/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py +++ b/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py @@ -77,10 +77,10 @@ class TestScatterNdAddSimpleOp(OpTest): self.outputs = {'Out': expect_np} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out', check_eager=True) + self.check_grad(['X', 'Updates'], 'Out', check_eager=False) class TestScatterNdAddWithEmptyIndex(OpTest): @@ -101,10 +101,10 @@ class TestScatterNdAddWithEmptyIndex(OpTest): self.outputs = {'Out': expect_np} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out', check_eager=True) + self.check_grad(['X', 'Updates'], 'Out', check_eager=False) class TestScatterNdAddWithHighRankSame(OpTest): @@ -128,10 +128,10 @@ class TestScatterNdAddWithHighRankSame(OpTest): self.outputs = {'Out': expect_np} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out', check_eager=True) + self.check_grad(['X', 'Updates'], 'Out', check_eager=False) class TestScatterNdAddWithHighRankDiff(OpTest): @@ -154,10 +154,10 @@ class TestScatterNdAddWithHighRankDiff(OpTest): self.outputs = {'Out': expect_np} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out', check_eager=True) + self.check_grad(['X', 'Updates'], 'Out', check_eager=False) #Test Python API diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py index 5cb9b436b5a9251de71d9e698ab6e217f4f95b28..d7f8886dcd3c17d1ed5dada0963d225ed5ae19bb 100644 --- a/python/paddle/fluid/tests/unittests/test_scatter_op.py +++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py @@ -37,10 +37,10 @@ class TestScatterOp(OpTest): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(["X", "Updates"], "Out", check_eager=True) + self.check_grad(["X", "Updates"], "Out", check_eager=False) class TestScatterOp0(OpTest): @@ -57,10 +57,10 @@ class TestScatterOp0(OpTest): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(["X", "Updates"], "Out", check_eager=True) + self.check_grad(["X", "Updates"], "Out", check_eager=False) class TestScatterOp1(OpTest): @@ -80,10 +80,10 @@ class TestScatterOp1(OpTest): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(["X", "Updates"], "Out", check_eager=True) + self.check_grad(["X", "Updates"], "Out", check_eager=False) @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -103,13 +103,13 @@ class TestScatterOp2(OpTest): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-3, check_eager=True) + self.check_output_with_place(place, atol=1e-3, check_eager=False) 
def test_check_grad(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X', 'Updates'], 'Out', check_eager=True) + place, ['X', 'Updates'], 'Out', check_eager=False) @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -133,13 +133,13 @@ class TestScatterOp3(OpTest): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-3, check_eager=True) + self.check_output_with_place(place, atol=1e-3, check_eager=False) def test_check_grad(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X', 'Updates'], 'Out', check_eager=True) + place, ['X', 'Updates'], 'Out', check_eager=False) class TestScatterOp4(OpTest): @@ -155,10 +155,10 @@ class TestScatterOp4(OpTest): self.outputs = {'Out': output_np} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out', check_eager=True) + self.check_grad(['X', 'Updates'], 'Out', check_eager=False) @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -178,13 +178,13 @@ class TestScatterOp5(OpTest): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-3, check_eager=True) + self.check_output_with_place(place, atol=1e-3, check_eager=False) def test_check_grad(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X', 'Updates'], 'Out', check_eager=True) + place, ['X', 'Updates'], 'Out', check_eager=False) class TestScatterAPI(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py b/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py index aac8b6a99b649176d29224e22c3d3258d96c194e..086527ab554357bc40be958d5dbd0bfa18b37289 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py +++ b/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py @@ -20,6 +20,8 @@ import numpy as np import paddle import paddle.nn as nn from paddle.fluid.framework import _test_eager_guard, _in_eager_mode +import paddle.fluid as fluid +import paddle.fluid.core as core class SimpleNet(nn.Layer): @@ -445,8 +447,7 @@ class TestTensorRegisterHook(unittest.TestCase): self.func_multiple_hooks_for_interior_var() self.func_multiple_hooks_for_interior_var() - # TODO(wuweilong): enable this case when DoubleGrad in eager mode is ready - def test_hook_in_double_grad(self): + def func_hook_in_double_grad(self): def double_print_hook(grad): grad = grad * 2 print(grad) @@ -461,10 +462,11 @@ class TestTensorRegisterHook(unittest.TestCase): x.register_hook(double_print_hook) y = x * x - + fluid.set_flags({'FLAGS_retain_grad_for_all_tensor': False}) # Since y = x * x, dx = 2 * x dx = paddle.grad( outputs=[y], inputs=[x], create_graph=True, retain_graph=True)[0] + fluid.set_flags({'FLAGS_retain_grad_for_all_tensor': True}) z = y + dx self.assertTrue(x.grad is None) @@ -475,8 +477,17 @@ class TestTensorRegisterHook(unittest.TestCase): # x.gradient() = 2 * x + 2 = 4.0 # after changed by hook: 8.0 - z.backward() - self.assertTrue(np.array_equal(x.grad.numpy(), np.array([8.]))) + # TODO(wuweilong): enable this case when DoubleGrad in eager mode is ready + if core._in_eager_mode(): + pass + else: + z.backward() + self.assertTrue(np.array_equal(x.grad.numpy(), np.array([8.]))) + + def 
test_hook_in_double_grad(self): + with _test_eager_guard(): + self.func_hook_in_double_grad() + self.func_hook_in_double_grad() def func_remove_one_hook_multiple_times(self): for device in self.devices: diff --git a/python/paddle/fluid/tests/unittests/test_trunc_op.py b/python/paddle/fluid/tests/unittests/test_trunc_op.py index b70fa04adc13cfd16c43010cce46b31893052927..5bb3e99ee302fc8812635f2905086a44a0b95447 100644 --- a/python/paddle/fluid/tests/unittests/test_trunc_op.py +++ b/python/paddle/fluid/tests/unittests/test_trunc_op.py @@ -29,6 +29,7 @@ paddle.enable_static() class TestTruncOp(OpTest): def setUp(self): self.op_type = "trunc" + self.python_api = paddle.trunc self.dtype = np.float64 np.random.seed(2021) self.inputs = {'X': np.random.random((20, 20)).astype(self.dtype)} diff --git a/python/paddle/fluid/tests/unittests/test_where_op.py b/python/paddle/fluid/tests/unittests/test_where_op.py index 4cfd243ddb46a9c3607bf03d7129c6ee61b3b350..36819e089edbf15eb8871deb1a1e5b28c6b6808d 100644 --- a/python/paddle/fluid/tests/unittests/test_where_op.py +++ b/python/paddle/fluid/tests/unittests/test_where_op.py @@ -35,10 +35,10 @@ class TestWhereOp(OpTest): self.outputs = {'Out': np.where(self.cond, self.x, self.y)} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X', 'Y'], 'Out', check_eager=True) + self.check_grad(['X', 'Y'], 'Out', check_eager=False) def init_config(self): self.x = np.random.uniform((-3), 5, 100).astype('float64') diff --git a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py index f210d97362cf062260594dce1112059919f179c4..05a4dfe3c06b61aa79fdda0619715466d81011c4 100644 --- a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py @@ -109,7 +109,7 @@ class TestYoloBoxOp(OpTest): self.outputs = {'Boxes': boxes, 'Scores': scores} def test_check_output(self): - self.check_output(check_eager=True) + self.check_output(check_eager=False) def initTestCase(self): self.anchors = [10, 13, 16, 30, 33, 23] diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index e1dd5f5e61d96d54873800770a77d37ca36db8fe..da383db0effea95abc1afebbfcadc6f4fc3a3a45 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -243,6 +243,8 @@ def add(x, y, name=None): """ if paddle.in_dynamic_mode(): + #if _in_eager_mode(): + #return _C_ops.final_state_add(x, y) return _C_ops.elementwise_add(x, y) return _elementwise_op(LayerHelper('elementwise_add', **locals())) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 225e22cb90c092c73d4dc9417efb0f3acb44244f..9cd60226a11161269b2cdb02a8660a160380f51d 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -126,7 +126,7 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): axis(int, optional): Axis to compute indices along. The effective range is [-R, R), where R is x.ndim. when axis < 0, it works the same way as axis + R. Default is None, the input `x` will be into the flatten tensor, and selecting the min value index. - keepdim(bool, optional): Keep the axis that selecting max. The defalut value is False. + keepdim(bool, optional): Whether to keep the given axis in output. If it is True, the dimensions will be same as input x and with size one in the axis. 
Otherwise the output has one fewer dimension than x since the axis is squeezed. Default is False. dtype(str|np.dtype, optional): Data type of the output tensor which can be int32, int64. The default value is 'int64', and it will return the int64 indices. @@ -147,12 +147,15 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): [6,9,2,4]]) out1 = paddle.argmax(x) print(out1) # 2 - out2 = paddle.argmax(x, axis=1) + out2 = paddle.argmax(x, axis=0) print(out2) - # [2 3 1] + # [2, 2, 0, 1] out3 = paddle.argmax(x, axis=-1) print(out3) - # [2 3 1] + # [2, 3, 1] + out4 = paddle.argmax(x, axis=0, keepdim=True) + print(out4) + # [[2, 2, 0, 1]] """ if axis is not None and not isinstance(axis, int): raise TypeError( @@ -206,7 +209,7 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None): axis(int, optional): Axis to compute indices along. The effective range is [-R, R), where R is x.ndim. when axis < 0, it works the same way as axis + R. Default is None, the input `x` will be into the flatten tensor, and selecting the min value index. - keepdim(bool, optional): Keep the axis that selecting min. The defalut value is False. + keepdim(bool, optional): Whether to keep the given axis in output. If it is True, the dimensions will be same as input x and with size one in the axis. Otherwise the output has one fewer dimension than x since the axis is squeezed. Default is False. dtype(str): Data type of the output tensor which can be int32, int64. The default value is 'int64', and it will return the int64 indices. @@ -227,12 +230,15 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None): [6,9,2,4]]) out1 = paddle.argmin(x) print(out1) # 4 - out2 = paddle.argmin(x, axis=1) + out2 = paddle.argmin(x, axis=0) print(out2) - # [0 0 2] + # [1, 1, 1, 2] out3 = paddle.argmin(x, axis=-1) print(out3) - # [0 0 2] + # [0, 0, 2] + out4 = paddle.argmin(x, axis=0, keepdim=True) + print(out4) + # [[1, 1, 1, 2]] """ if axis is not None and not isinstance(axis, int): raise TypeError(