Commit 7a7beb81 authored by: P phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_some_yaml_config

......@@ -11,7 +11,7 @@ elseif(NEW_RELEASE_ALL)
add_definitions(-DNEW_RELEASE_ALL)
set(paddle_known_gpu_archs "35 50 52 60 61 70 75 80 86")
set(paddle_known_gpu_archs10 "35 50 52 60 61 70 75")
set(paddle_known_gpu_archs11 "35 50 52 60 61 70 75 80")
set(paddle_known_gpu_archs11 "35 50 60 61 70 75 80")
elseif(NEW_RELEASE_PYPI)
message("Using New Release Strategy - Cubin Packge")
add_definitions(-DNEW_RELEASE_PYPI)
......
......@@ -696,14 +696,15 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
FUNCTION_TEMPLATE = """
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
// Call grad_api function
VLOG(3) << \"Finally State Running: \" << \"{}\";
auto grad_api_returns = {}::{}({});
{}
}}
"""
node_definition_str = FUNCTION_TEMPLATE.format(
grad_node_name, grad_api_namespace, bwd_api_name, grad_api_args_str,
returns_str)
grad_node_name, grad_node_name, grad_api_namespace, bwd_api_name,
grad_api_args_str, returns_str)
return node_definition_str
......
......@@ -612,7 +612,9 @@ std::vector<paddle::experimental::Tensor> RunBackward(
for (size_t i = 0; i < edges.size(); i++) {
for (size_t j = 0; j < edges[i].size(); j++) {
const Edge& edge = edges[i][j];
if (!edge.IsInitialized()) {
continue;
}
auto edge_rank = edge.GetEdgeRankInfo();
// Since we give each edge the same rank as the bwd outputs, we index them
// with
......
......@@ -63,6 +63,8 @@ void GradNodeBase::AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id) {
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo());
} else {
adj_edges_[slot_id].emplace_back();
}
}
}
......@@ -85,6 +87,8 @@ void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) {
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo());
} else {
adj_edges_[slot_id].emplace_back();
}
}
......
......@@ -257,12 +257,22 @@ class Edge {
}
// Currently we use grad_node_ to identify whether an edge is initialized.
bool IsInitialized() const { return grad_node_.get(); }
bool IsInitialized() const {
if (!grad_node_) {
return false;
} else {
if (!(grad_node_.get())) {
return false;
} else {
return true;
}
}
}
private:
size_t in_slot_id_;
size_t in_rank_;
std::shared_ptr<GradNodeBase> grad_node_;
std::shared_ptr<GradNodeBase> grad_node_{nullptr};
};
} // namespace egr
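The three eager-autograd hunks above work together: AddEdges now emplaces a default-constructed (uninitialized) Edge when the autograd meta is null, Edge::IsInitialized reports whether a grad node is attached, and RunBackward skips uninitialized edges. Below is a minimal standalone sketch of that control flow; Edge and GradNodeBase here are simplified stand-ins, not the real Paddle classes.

#include <memory>
#include <vector>

struct GradNodeBase {};

class Edge {
 public:
  Edge() = default;  // default-constructed edge carries no grad node
  explicit Edge(std::shared_ptr<GradNodeBase> node) : grad_node_(std::move(node)) {}
  // Equivalent to the longer null checks above: a shared_ptr converts to
  // false exactly when it holds no object.
  bool IsInitialized() const { return static_cast<bool>(grad_node_); }

 private:
  std::shared_ptr<GradNodeBase> grad_node_{nullptr};
};

int main() {
  std::vector<Edge> adj_edges;
  adj_edges.emplace_back();                                  // placeholder edge (null meta)
  adj_edges.emplace_back(std::make_shared<GradNodeBase>());  // real edge
  int visited = 0;
  for (const Edge& e : adj_edges) {
    if (!e.IsInitialized()) continue;  // skipped, as in RunBackward above
    ++visited;
  }
  return visited == 1 ? 0 : 1;  // exactly one edge is traversed
}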
......@@ -203,14 +203,6 @@ REGISTER_OPERATOR(lookup_table_v2_grad, ops::LookupTableV2OpGrad,
ops::LookupTableV2GradOpNoBufferVarsInferer,
ops::LookupTableV2OpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel<float>,
ops::LookupTableV2Kernel<double>,
ops::LookupTableV2Kernel<paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
lookup_table_v2_grad, ops::LookupTableV2GradKernel<float>,
ops::LookupTableV2GradKernel<double>,
ops::LookupTableV2GradKernel<paddle::platform::bfloat16>);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(lookup_table_v2)
.AddCheckpoint(
......
......@@ -235,13 +235,3 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(lookup_table_v2, ops::LookupTableV2CUDAKernel<float>,
ops::LookupTableV2CUDAKernel<double>,
ops::LookupTableV2CUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(lookup_table_v2_grad,
ops::LookupTableV2GradCUDAKernel<float>,
ops::LookupTableV2GradCUDAKernel<double>,
ops::LookupTableV2GradCUDAKernel<plat::float16>);
......@@ -12,60 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reverse_op.h"
#include <memory>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class ReverseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Reverse");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Reverse");
auto x_var_type = ctx->GetInputsVarType("X")[0];
const auto& axis = ctx->Attrs().Get<std::vector<int>>("axis");
if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
PADDLE_ENFORCE_EQ(
axis.size(), 1,
platform::errors::InvalidArgument(
"The size of axis must be 1 when the Input(X) is LoDTensorArray, "
"but received %d.",
axis.size()));
PADDLE_ENFORCE_EQ(axis[0], 0, platform::errors::InvalidArgument(
"The value of axis should be 1 when "
"the Input(X) is LoDTensorArray, "
"but received %d.",
axis[0]));
// In runtime, shape is determined by RunImpl.
if (!ctx->IsRuntime()) {
const auto& x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
}
return;
}
const auto& x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_NE(axis.empty(), true, platform::errors::InvalidArgument(
"'axis' can not be empty."));
for (int a : axis) {
PADDLE_ENFORCE_LT(a, x_dims.size(),
paddle::platform::errors::OutOfRange(
"The axis must be less than input tensor's rank. "
"but got %d >= %d",
a, x_dims.size()));
PADDLE_ENFORCE_GE(
a, -x_dims.size(),
paddle::platform::errors::OutOfRange(
"The axis must be greater than the negative number of "
"input tensor's rank, but got %d < %d",
a, -x_dims.size()));
}
ctx->SetOutputDim("Out", x_dims);
}
};
class ReverseOpVarTypeInference : public framework::VarTypeInference {
......@@ -134,23 +94,10 @@ class ReverseGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(reverse, ReverseInferShapeFunctor,
PD_INFER_META(phi::ReverseInferMeta));
REGISTER_OPERATOR(reverse, ops::ReverseOp, ops::ReverseOpMaker,
ops::ReverseGradMaker<paddle::framework::OpDesc>,
ops::ReverseGradMaker<paddle::imperative::OpBase>,
ops::ReverseOpVarTypeInference);
ops::ReverseOpVarTypeInference, ReverseInferShapeFunctor);
REGISTER_OPERATOR(reverse_grad, ops::ReverseOp, ops::ReverseOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(
reverse, ops::ReverseKernel<paddle::platform::CPUDeviceContext, int>,
ops::ReverseKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::ReverseKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ReverseKernel<paddle::platform::CPUDeviceContext, bool>,
ops::ReverseKernel<paddle::platform::CPUDeviceContext, float>,
ops::ReverseKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
reverse, ops::ReverseKernel<paddle::platform::CUDADeviceContext, int>,
ops::ReverseKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::ReverseKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ReverseKernel<paddle::platform::CUDADeviceContext, bool>,
ops::ReverseKernel<paddle::platform::CUDADeviceContext, float>,
ops::ReverseKernel<paddle::platform::CUDADeviceContext, double>);
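For context, the reverse_op change follows the InferShape-to-phi migration pattern visible above: the operator-local InferShape is removed and the operator is wired to a phi InferMeta function through DECLARE_INFER_SHAPE_FUNCTOR. A hedged sketch of that registration pattern with placeholder names (my_op, MyOp, MyOpInferMeta and friends are hypothetical, not part of this patch):

// Declare a shape functor bound to a phi InferMeta function, then pass it to
// REGISTER_OPERATOR so shape inference is shared between fluid and phi.
DECLARE_INFER_SHAPE_FUNCTOR(my_op, MyOpInferShapeFunctor,
                            PD_INFER_META(phi::MyOpInferMeta));
REGISTER_OPERATOR(my_op, ops::MyOp, ops::MyOpMaker,
                  ops::MyOpGradMaker<paddle::framework::OpDesc>,
                  ops::MyOpGradMaker<paddle::imperative::OpBase>,
                  MyOpInferShapeFunctor);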
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T, int Rank>
struct ReverseFunctor {
void operator()(const DeviceContext& context, const framework::LoDTensor& in,
framework::LoDTensor* out, const std::vector<int>& axis) {
Eigen::DSizes<bool, Rank> reverse_axis;
for (int i = 0; i < Rank; ++i) {
reverse_axis[i] = false;
}
for (int a : axis) {
if (a >= 0) {
reverse_axis[a] = true;
} else {
reverse_axis[Rank + a] = true;
}
}
auto in_eigen = framework::EigenTensor<T, Rank>::From(in);
auto out_eigen = framework::EigenTensor<T, Rank>::From(*out);
auto& dev = *context.eigen_device();
EigenReverse<std::decay_t<decltype(dev)>, T, Rank>::Eval(
dev, out_eigen, in_eigen, reverse_axis);
}
};
template <typename DeviceContext, typename T>
class ReverseKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x_var = context.InputVar("X");
const auto& axis = context.Attr<std::vector<int>>("axis");
if (x_var->IsType<framework::LoDTensorArray>()) {
auto& x_array = x_var->Get<framework::LoDTensorArray>();
auto* out_array = context.Output<framework::LoDTensorArray>("Out");
out_array->resize(x_array.size());
for (size_t offset = 0; offset < x_array.size(); offset++) {
auto& x_tensor = x_array.at(offset);
PADDLE_ENFORCE_GT(
x_tensor.memory_size(), 0,
platform::errors::PreconditionNotMet(
"The input LoDTensorArray X[%d] holds no memory.", offset));
auto out_offset = x_array.size() - offset - 1;
auto* out_tensor = &out_array->at(out_offset);
out_tensor->set_lod(x_tensor.lod());
paddle::framework::TensorCopy(x_tensor, context.GetPlace(), out_tensor);
}
return;
}
auto* x = context.Input<framework::LoDTensor>("X");
auto* out = context.Output<framework::LoDTensor>("Out");
out->mutable_data<T>(context.GetPlace());
int rank = x->dims().size();
auto& dev_ctx = context.template device_context<DeviceContext>();
switch (rank) {
case 1:
ReverseFunctor<DeviceContext, T, 1> functor1;
functor1(dev_ctx, *x, out, axis);
break;
case 2:
ReverseFunctor<DeviceContext, T, 2> functor2;
functor2(dev_ctx, *x, out, axis);
break;
case 3:
ReverseFunctor<DeviceContext, T, 3> functor3;
functor3(dev_ctx, *x, out, axis);
break;
case 4:
ReverseFunctor<DeviceContext, T, 4> functor4;
functor4(dev_ctx, *x, out, axis);
break;
case 5:
ReverseFunctor<DeviceContext, T, 5> functor5;
functor5(dev_ctx, *x, out, axis);
break;
case 6:
ReverseFunctor<DeviceContext, T, 6> functor6;
functor6(dev_ctx, *x, out, axis);
break;
default:
PADDLE_THROW(paddle::platform::errors::OutOfRange(
"The reserve operator does not support input tensors"
"whose ranks are greater than 6."));
}
}
};
} // namespace operators
} // namespace paddle
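ReverseKernel above turns the runtime tensor rank into the compile-time Rank parameter of ReverseFunctor through a switch. A toy standalone sketch of that dispatch idiom (PrintRank and Dispatch are illustrative only, not Paddle code):

#include <cstdio>

template <int Rank>
struct PrintRank {
  void operator()() const { std::printf("rank = %d\n", Rank); }
};

// Map a runtime rank to a compile-time template argument, as ReverseKernel
// does for ranks 1..6.
void Dispatch(int rank) {
  switch (rank) {
    case 1: PrintRank<1>{}(); break;
    case 2: PrintRank<2>{}(); break;
    case 3: PrintRank<3>{}(); break;
    default: std::printf("unsupported rank %d\n", rank); break;
  }
}

int main() { Dispatch(2); }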
......@@ -5,6 +5,10 @@ endif()
add_subdirectory(ir)
add_subdirectory(pass)
add_executable(phi-ir-exec phi_ir_exec.cc)
target_link_libraries(phi-ir-exec infrt)
add_executable(phi-exec phi_exec.cc)
target_link_libraries(phi-exec infrt)
......
......@@ -18,8 +18,8 @@ def PHI_Dialect : Dialect {
def PhiOpTrait : NativeOpTrait<"PhiOpTrait">;
class PHI_Type<string type, list<Trait> traits = []>
: TypeDef<PHI_Dialect, type, !listconcat(traits, [PhiOpTrait, IsolatedFromAbove])> {}
class PHI_Type<string type, list<Trait> traits = [], string baseCppClass = "::mlir::Type">
: TypeDef<PHI_Dialect, type, !listconcat(traits, [PhiOpTrait, IsolatedFromAbove]), baseCppClass> {}
def Allocator : PHI_Type<"Allocator"> {
let mnemonic = "allocator";
......
......@@ -16,6 +16,7 @@
#include "paddle/infrt/dialect/infrt/ir/basic_kernels.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/pd/common/pd_ops_info.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
MLIRModelGenImpl::MLIRModelGenImpl()
: context_(infrt::Global::getMLIRContext()), builder_(context_) {
......@@ -24,6 +25,8 @@ MLIRModelGenImpl::MLIRModelGenImpl()
context_->getOrLoadDialect<infrt::dt::DTDialect>();
context_->getOrLoadDialect<infrt::pd::PaddleDialect>();
context_->getOrLoadDialect<::infrt::InfrtDialect>();
context_->getOrLoadDialect<::infrt::phi::PHIDialect>();
context_->getOrLoadDialect<::infrt::phi::PHIDenseTensorDialect>();
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(context_));
}
......@@ -79,7 +82,7 @@ mlir::FuncOp MLIRModelGenImpl::UpdateModelModule(
llvm::SmallVector<mlir::Type, 4> MLIRModelGenImpl::GetModelInputsType(
const infrt::paddle::framework_proto::ProgramDesc &program) {
llvm::SmallVector<mlir::Type, 4> operandTypes;
operandTypes.push_back(infrt::DenseHostTensorMapType::get(context_));
operandTypes.push_back(infrt::phi::DenseTensorMapType::get(context_));
for (auto &op_desc : main_block_.ops()) {
if (op_desc.type() != "feed") continue;
for (int var_idx = 0; var_idx < op_desc.outputs_size(); ++var_idx) {
......@@ -180,7 +183,7 @@ void MLIRModelGenImpl::UpdateModelParams(
&precision_);
mlir::Type type_ = infrt::DenseTensorType::get(
context_, infrt::TargetType::CPU, precision_, infrt::LayoutType::ANY);
auto op = builder_.create<infrt::dt::TensorMapGetTensorOp>(
auto op = builder_.create<::infrt::phi::TensorMapGetTensorOp>(
mlir::UnknownLoc::get(context_), type_, map, name);
params_map_.insert(std::pair<std::string, mlir::Value>(
var_desc.name(), op.getOperation()->getResult(0)));
......
......@@ -6,3 +6,4 @@ add_test(NAME test_infrt_by_lit COMMAND sh -c "lit -v ${CMAKE_SOURCE_DIR}/paddle
DEPENDS infrtopt infrtexec)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/phi/linear_cpu.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/phi/linear_cpu.mlir)
// RUN: infrtexec -i %s
module {
func @main_graph(%arg0: !phi.dense_tensor_map, %arg1: !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW> {
%0 = phi_dt.tensor_map_get_tensor(%arg0) {name = "linear_0.w_0"} -> !infrt.dense_tensor<CPU, FP32, NCHW>
%1 = phi_dt.tensor_map_get_tensor(%arg0) {name = "linear_0.b_0"} -> !infrt.dense_tensor<CPU, FP32, NCHW>
%2 = "phi_dt.create_context.cpu"() : () -> !phi.context<CPU>
%5 = "phi_cpu.matmul.float32.any"(%2, %arg1, %0) {trans_x = false, trans_y = false} : (!phi.context<CPU>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
%7 = "phi_cpu.add.float32.any"(%2, %5, %1): (!phi.context<CPU>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
infrt.return %7 : !infrt.dense_tensor<CPU, FP32, NCHW>
}
func @main() {
%ctx = "phi_dt.create_context.cpu" (): () -> !phi.context<CPU>
%1 = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision<FP32>, layout=#infrt.layout<NCHW>, lod=[1:i64], dims=[16:i64, 784:i64]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
%map = phi_dt.load_combined_params(){model_path="@CMAKE_BINARY_DIR@/linear/linear.pdmodel",params_path="@CMAKE_BINARY_DIR@/linear/linear.pdiparams"}
%2 = infrt.call@main_graph(%map, %1) : (!phi.dense_tensor_map, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
phi_dt.print_tensor (%2 : !infrt.dense_tensor<CPU, FP32, NCHW>)
infrt.return
}
}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# example 1: save layer
import numpy as np
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
BATCH_SIZE = 16
BATCH_NUM = 4
EPOCH_NUM = 4
IMAGE_SIZE = 784
CLASS_NUM = 10
# define a random dataset
class RandomDataset(paddle.io.Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
@paddle.jit.to_static
def forward(self, x):
return self._linear(x)
def train(layer, loader, loss_fn, opt):
for epoch_id in range(EPOCH_NUM):
for batch_id, (image, label) in enumerate(loader()):
out = layer(image)
loss = loss_fn(out, label)
loss.backward()
opt.step()
opt.clear_grad()
# 1. train & save model.
# create network
layer = LinearNet()
loss_fn = nn.CrossEntropyLoss()
adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())
# create data loader
dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
loader = paddle.io.DataLoader(
dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2)
# train
train(layer, loader, loss_fn, adam)
# save
path = "linear/linear"
paddle.jit.save(layer, path)
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <mutex>
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/macros.h"
#include "paddle/utils/flat_hash_map.h"
......@@ -58,21 +60,22 @@ class DeviceContextPool {
public:
static DeviceContextPool& Instance();
const phi::DeviceContext* Get(const Place& place) const;
const phi::DeviceContext* Get(const Place& place);
phi::DeviceContext* GetMutable(const Place& place);
template <AllocationType T>
const typename DefaultDeviceContextType<T>::TYPE* Get(
const Place& place) const {
const typename DefaultDeviceContextType<T>::TYPE* Get(const Place& place) {
return reinterpret_cast<const typename DefaultDeviceContextType<T>::TYPE*>(
Get(place));
}
private:
DeviceContextPool();
DeviceContextPool() = default;
paddle::flat_hash_map<Place, const phi::DeviceContext*, Place::Hash>
context_map_;
std::mutex mutex_;
DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};
......
......@@ -25,12 +25,17 @@ DeviceContextPool& DeviceContextPool::Instance() {
return g_device_context_pool;
}
const phi::DeviceContext* DeviceContextPool::Get(const Place& place) const {
const phi::DeviceContext* DeviceContextPool::Get(const Place& place) {
auto it = context_map_.find(place);
PADDLE_ENFORCE_NE(
it,
context_map_.end(),
phi::errors::NotFound("The DeviceContext of %s does not exists.", place));
if (it == context_map_.end()) {
// Get and cache the DeviceContext only when it is actually needed.
auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
{
std::lock_guard<std::mutex> lock(mutex_);
context_map_[place] = dev_ctx;
}
return dev_ctx;
}
return it->second;
}
......@@ -38,28 +43,5 @@ phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) {
return const_cast<phi::DeviceContext*>(Get(place));
}
DeviceContextPool::DeviceContextPool() {
// We need to make sure that the correct value exists
// whenever we get the DeviceContext from DeviceContextPool
const auto& device_contexts =
paddle::platform::DeviceContextPool::Instance().device_contexts();
for (const auto& pair : device_contexts) {
// only get CPU and GPU DeviceContext now, add other DeviceContext type
// later if needed
if (platform::is_cpu_place(pair.first)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
||
platform::is_gpu_place(pair.first)) {
#else
) {
#endif
const phi::DeviceContext* dev_ctx = pair.second.get().get();
VLOG(3) << "Init phi DeviceContextPool: insert {" << pair.first << ", "
<< dev_ctx << "}";
context_map_[pair.first] = dev_ctx;
}
}
}
} // namespace experimental
} // namespace paddle
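The DeviceContextPool change above replaces eager population in the constructor with lazy lookup: Get now fetches the context from the fluid pool on first use and caches it under a mutex. A minimal sketch of that get-and-cache pattern, assuming placeholder Key/Value types and a CreateValue factory (none of these are Paddle APIs); unlike the patch, the sketch holds the lock around the lookup as well.

#include <map>
#include <mutex>
#include <string>

struct Value { int id; };

class LazyPool {
 public:
  const Value* Get(const std::string& key) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      // Materialize the value only when it is first requested, then cache it.
      const Value* v = CreateValue(key);
      cache_[key] = v;
      return v;
    }
    return it->second;
  }

 private:
  static const Value* CreateValue(const std::string&) {
    static Value v{42};
    return &v;
  }
  std::map<std::string, const Value*> cache_;
  std::mutex mutex_;
};

int main() {
  LazyPool pool;
  return pool.Get("cpu") == pool.Get("cpu") ? 0 : 1;  // second call hits the cache
}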
......@@ -23,8 +23,6 @@ limitations under the License. */
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
namespace phi {
namespace detail {
......@@ -469,6 +467,31 @@ void ConvInferMeta(const MetaTensor& input,
out->set_dtype(input.dtype());
}
void ConvInferInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
MetaTensor* out,
MetaConfig config) {
ConvInferMeta(input,
filter,
strides,
paddings,
paddding_algorithm,
groups,
dilations,
data_format,
/*use_addto=*/false,
/*workspace_size_MB=*/512, // useless in infermeta
/*exhaustive_search=*/false,
out,
config);
}
void ConvTransposeInferMeta(const MetaTensor& x,
const MetaTensor& filter,
const std::vector<int>& strides,
......@@ -1670,3 +1693,4 @@ void ValueCompareInferMeta(const MetaTensor& x,
PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
PD_REGISTER_INFER_META_FN(conv2d, phi::ConvInferMeta);
PD_REGISTER_INFER_META_FN(conv2d_infer, phi::ConvInferInferMeta);
......@@ -83,6 +83,17 @@ void ConvInferMeta(const MetaTensor& input,
MetaTensor* out,
MetaConfig config = MetaConfig());
void ConvInferInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
MetaTensor* out,
MetaConfig config = MetaConfig());
void ConvTransposeInferMeta(const MetaTensor& x,
const MetaTensor& filter,
const std::vector<int>& strides,
......
......@@ -1240,6 +1240,33 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
ReshapeInferMeta(x, shape, out, config);
}
void ReverseInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out) {
PADDLE_ENFORCE_NE(axis.empty(),
true,
phi::errors::InvalidArgument("'axis' can not be empty."));
const auto& x_dims = x.dims();
for (int a : axis) {
PADDLE_ENFORCE_LT(a,
x_dims.size(),
phi::errors::OutOfRange(
"The axis must be less than input tensor's rank. "
"but got %d >= %d",
a,
x_dims.size()));
PADDLE_ENFORCE_GE(
a,
-x_dims.size(),
phi::errors::OutOfRange(
"The axis must be greater than the negative number of "
"input tensor's rank, but got %d < %d",
a,
-x_dims.size()));
}
out->share_meta(x);
}
void RollInferMeta(const MetaTensor& x,
const ScalarArray& shifts,
const std::vector<int64_t>& axis,
......
......@@ -198,6 +198,10 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void ReverseInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out);
void RollInferMeta(const MetaTensor& x,
const ScalarArray& shifts,
const std::vector<int64_t>& axis,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/conv_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
namespace phi {
template <typename T, typename Context>
void ConvInferKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
DenseTensor* out) {
ConvKernel<T, Context>(dev_ctx,
input,
filter,
strides,
paddings,
paddding_algorithm,
groups,
dilations,
data_format,
/*use_addto=*/false,
/*workspace_size_MB=*/paddle::platform::
GetDefaultConvWorkspaceSizeLimitMB(),
/*exhaustive_search=*/false,
out);
}
} // namespace phi
PD_REGISTER_KERNEL(
conv2d_infer, CPU, ALL_LAYOUT, phi::ConvInferKernel, float, double) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(
conv2d_infer, GPU, ALL_LAYOUT, phi::ConvInferKernel, float, double) {}
#endif
......@@ -64,4 +64,16 @@ void DepthwiseConvKernel(const Context& dev_ctx,
bool fuse_relu,
DenseTensor* out);
template <typename T, typename Context>
void ConvInferKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
struct EmbeddingGradCPUFunctor {
EmbeddingGradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
weight_grad_(weight_grad),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
DDim table_dim = weight_.dims();
auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_num = static_cast<int64_t>(ids.size());
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
{
auto* d_output = &out_grad_;
auto* ids_data = ids.data();
int64_t N = table_dim[0];
int64_t D = table_dim[1];
auto* d_output_data = d_output->template data<T>();
dev_ctx_.template Alloc<T>(weight_grad_);
auto* d_table_data = weight_grad_->data<T>();
memset(d_table_data, 0, weight_grad_->numel() * sizeof(T));
for (int64_t i = 0; i < ids_num; ++i) {
if (padding_idx_ != kNoPadding && ids_data[i] == padding_idx_) {
// the gradient of padding_idx should be 0, already done by memset, so
// do nothing.
} else {
PADDLE_ENFORCE_LT(
ids_data[i],
N,
phi::errors::InvalidArgument(
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
ids_data[i]));
PADDLE_ENFORCE_GE(
ids_data[i],
0,
phi::errors::InvalidArgument(
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
ids_data[i]));
for (int j = 0; j < D; ++j) {
d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
}
}
}
}
}
private:
const Context& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
const DenseTensor& out_grad_;
DenseTensor* weight_grad_;
int64_t padding_idx_;
};
template <typename T, typename Context>
void EmbeddingGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad) {
EmbeddingGradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
}
}
template <typename T, typename Context>
struct EmbeddingSparseGradCPUFunctor {
EmbeddingSparseGradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
weight_grad_(weight_grad),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
DDim table_dim = weight_.dims();
auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_num = static_cast<int64_t>(ids.size());
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
auto* d_table = weight_grad_;
auto* d_output = &out_grad_;
d_table->set_rows(ids);
auto* d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]});
dev_ctx_.template Alloc<T>(d_table_value);
d_table->set_height(table_dim[0]);
auto* d_output_data = d_output->template data<T>();
auto* d_table_data = d_table_value->template data<T>();
auto d_output_dims = d_output->dims();
auto d_output_dims_2d =
flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
PADDLE_ENFORCE_EQ(d_table_value->dims(),
d_output_dims_2d,
phi::errors::InvalidArgument(
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s].",
d_table_value->dims(),
d_output_dims_2d));
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
}
private:
const Context& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
const DenseTensor& out_grad_;
SelectedRows* weight_grad_;
int64_t padding_idx_;
};
template <typename T, typename Context>
void EmbeddingSparseGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad) {
EmbeddingSparseGradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
}
}
} // namespace phi
PD_REGISTER_KERNEL(embedding_grad,
CPU,
ALL_LAYOUT,
phi::EmbeddingGradKernel,
float,
double,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(embedding_sparse_grad,
CPU,
ALL_LAYOUT,
phi::EmbeddingSparseGradKernel,
float,
double,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
template <typename T, typename Context>
struct EmbeddingCPUFunctor {
EmbeddingCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
int64_t padding_idx,
DenseTensor* out)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_(out),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_numel = static_cast<int64_t>(ids.size());
int64_t row_number = weight_.dims()[0];
int64_t row_width = weight_.dims()[1];
auto* table = weight_.data<T>();
dev_ctx_.template Alloc<T>(out_);
auto* output = out_->data<T>();
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(
ids[i],
row_number,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number,
ids[i]));
PADDLE_ENFORCE_GE(
ids[i],
0,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number,
ids[i]));
memcpy(output + i * row_width,
table + ids[i] * row_width,
row_width * sizeof(T));
}
}
}
private:
const Context& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
DenseTensor* out_;
int64_t padding_idx_;
};
template <typename T, typename Context>
void EmbeddingKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
int64_t padding_idx,
DenseTensor* out) {
EmbeddingCPUFunctor<T, Context> functor(ctx, input, weight, padding_idx, out);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
}
}
} // namespace phi
PD_REGISTER_KERNEL(embedding,
CPU,
ALL_LAYOUT,
phi::EmbeddingKernel,
float,
double,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reverse_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/reverse_kernel_impl.h"
PD_REGISTER_KERNEL(reverse,
CPU,
ALL_LAYOUT,
phi::ReverseKernel,
int,
uint8_t,
int64_t,
bool,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
template <typename T, typename Context>
struct SparseWeightEmbeddingGradCPUFunctor {
SparseWeightEmbeddingGradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
weight_grad_(weight_grad),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
DDim table_dim = weight_.dims();
auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_num = static_cast<int64_t>(ids.size());
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
{
auto* d_output = &out_grad_;
// auto d_table = weight_grad_;
auto* ids_data = ids.data();
int64_t N = table_dim[0];
int64_t D = table_dim[1];
auto* d_output_data = d_output->template data<T>();
dev_ctx_.template Alloc<T>(weight_grad_);
auto* d_table_data = weight_grad_->data<T>();
memset(d_table_data, 0, weight_grad_->numel() * sizeof(T));
for (int64_t i = 0; i < ids_num; ++i) {
if (padding_idx_ != kNoPadding && ids_data[i] == padding_idx_) {
// the gradient of padding_idx should be 0, already done by memset, so
// do nothing.
} else {
PADDLE_ENFORCE_LT(
ids_data[i],
N,
phi::errors::InvalidArgument(
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
ids_data[i]));
PADDLE_ENFORCE_GE(
ids_data[i],
0,
phi::errors::InvalidArgument(
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
ids_data[i]));
for (int j = 0; j < D; ++j) {
d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
}
}
}
}
}
private:
const Context& dev_ctx_;
const DenseTensor& input_;
const SelectedRows& weight_;
const DenseTensor& out_grad_;
DenseTensor* weight_grad_;
int64_t padding_idx_;
};
template <typename T, typename Context>
struct SparseWeightEmbeddingSparseGradCPUFunctor {
SparseWeightEmbeddingSparseGradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
weight_grad_(weight_grad),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
DDim table_dim = weight_.dims();
auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_num = static_cast<int64_t>(ids.size());
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
auto* d_table = weight_grad_;
auto* d_output = &out_grad_;
d_table->set_rows(ids);
auto* d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]});
dev_ctx_.template Alloc<T>(d_table_value);
d_table->set_height(table_dim[0]);
auto* d_output_data = d_output->template data<T>();
auto* d_table_data = d_table_value->template data<T>();
auto d_output_dims = d_output->dims();
auto d_output_dims_2d =
phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
PADDLE_ENFORCE_EQ(d_table_value->dims(),
d_output_dims_2d,
phi::errors::InvalidArgument(
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s].",
d_table_value->dims(),
d_output_dims_2d));
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
}
private:
const Context& dev_ctx_;
const DenseTensor& input_;
const SelectedRows& weight_;
const DenseTensor& out_grad_;
SelectedRows* weight_grad_;
int64_t padding_idx_;
};
template <typename T, typename Context>
void SparseWeightEmbeddingGradKernel(const Context& ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad) {
SparseWeightEmbeddingGradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
}
}
template <typename T, typename Context>
void SparseWeightEmbeddingSparseGradKernel(const Context& ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad) {
SparseWeightEmbeddingSparseGradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
}
}
} // namespace phi
PD_REGISTER_KERNEL(sparse_weight_embedding_grad,
CPU,
ALL_LAYOUT,
phi::SparseWeightEmbeddingGradKernel,
float,
double,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(sparse_weight_embedding_sparse_grad,
CPU,
ALL_LAYOUT,
phi::SparseWeightEmbeddingSparseGradKernel,
float,
double,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace phi {
template <typename T, typename Context>
struct EmbeddingCPUSparseFunctor {
EmbeddingCPUSparseFunctor(const Context& dev_ctx,
const DenseTensor& input,
const SelectedRows& weight,
int64_t padding_idx,
DenseTensor* out)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_(out),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_numel = static_cast<int64_t>(ids.size());
const auto& table_t = weight_;
auto output_t = out_;
int64_t row_width = table_t.value().dims()[1];
const auto* table = table_t.value().template data<T>();
auto* output = dev_ctx_.template Alloc<T>(output_t);
auto input_data_type =
paddle::framework::TransToProtoVarType(table_t.value().dtype());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(
ids[i],
0,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0. But received %ld",
ids[i]));
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(
id_index,
0,
phi::errors::InvalidArgument(
"the input key should be exists. But received %d.", id_index));
if (input_data_type == paddle::framework::proto::VarType::BF16) {
memcpy(output + i * row_width,
table + id_index * row_width,
row_width * sizeof(T));
} else {
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx_);
blas.VCOPY(
row_width, table + id_index * row_width, output + i * row_width);
}
}
}
}
private:
const Context& dev_ctx_;
const DenseTensor& input_;
const SelectedRows& weight_;
DenseTensor* out_;
int64_t padding_idx_;
};
template <typename T, typename Context>
void SparseWeightEmbeddingKernel(const Context& ctx,
const DenseTensor& input,
const SelectedRows& weight,
int64_t padding_idx,
DenseTensor* out) {
EmbeddingCPUSparseFunctor<T, Context> functor(
ctx, input, weight, padding_idx, out);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
}
}
} // namespace phi
PD_REGISTER_KERNEL(sparse_weight_embedding,
CPU,
ALL_LAYOUT,
phi::SparseWeightEmbeddingKernel,
float,
double,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
template <typename T, typename Context>
void EmbeddingGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad);
template <typename T, typename Context>
void EmbeddingSparseGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void EmbeddingKernel(const Context& ctx,
const DenseTensor& inputx,
const DenseTensor& weight,
int64_t padding_idx,
DenseTensor* out);
} // namespace phi
......@@ -36,8 +36,7 @@ inline void ResizeToChannelFirst(const DeviceContext& context,
in_dims_vec[3] = input->dims()[2];
in_dims_vec[4] = input->dims()[3];
transformed_input->Resize(make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
context.template Alloc<T>(transformed_input);
} else if (dim == 2) {
// input
transformed_input->Resize(input->dims());
......@@ -47,7 +46,7 @@ inline void ResizeToChannelFirst(const DeviceContext& context,
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
transformed_input->Resize(make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
context.template Alloc<T>(transformed_input);
} else if (dim == 1) {
transformed_input->Resize(input->dims());
......@@ -55,7 +54,7 @@ inline void ResizeToChannelFirst(const DeviceContext& context,
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[1];
transformed_input->Resize(make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
context.template Alloc<T>(transformed_input);
}
}
......@@ -74,7 +73,7 @@ inline void ResizeToChannelLast(const DeviceContext& context,
in_dims_vec[3] = input->dims()[4];
in_dims_vec[4] = input->dims()[1];
transformed_input->Resize(make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
context.template Alloc<T>(transformed_input);
} else if (dim == 2) {
// input
......@@ -85,7 +84,7 @@ inline void ResizeToChannelLast(const DeviceContext& context,
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[1];
transformed_input->Resize(make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
context.template Alloc<T>(transformed_input);
} else if (dim == 1) {
transformed_input->Resize(input->dims());
......@@ -93,7 +92,7 @@ inline void ResizeToChannelLast(const DeviceContext& context,
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[1];
transformed_input->Resize(make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
context.template Alloc<T>(transformed_input);
}
}
......
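Several hunks in this commit (the conv channel-order utilities above and the batch_norm kernels further below) apply the same allocation-API migration: the fluid-style tensor->mutable_data<T>(place) call is replaced by the phi-style dev_ctx.template Alloc<T>(tensor). A minimal sketch of the pattern; AllocOutput is a hypothetical helper, not part of the patch.

#include "paddle/phi/core/dense_tensor.h"

// Old style:  out->mutable_data<T>(dev_ctx.GetPlace());
// New style:  let the device context perform the allocation.
template <typename T, typename Context>
void AllocOutput(const Context& dev_ctx, phi::DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);
}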
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
constexpr int64_t kNoPadding = -1;
template <typename InT, typename OutT>
static std::vector<OutT> CopyIdsToVector(const DenseTensor &ids) {
auto numel = ids.numel();
const auto *src = ids.data<InT>();
std::vector<OutT> ret(numel);
if (std::is_same<InT, OutT>::value) {
std::memcpy(ret.data(), src, numel * sizeof(InT));
} else {
for (decltype(numel) i = 0; i < numel; ++i) {
ret[i] = src[i];
}
}
return ret;
}
} // namespace phi
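CopyIdsToVector above either memcpy's the ids when the input and output element types match or converts element by element when they differ; the embedding kernels use it to widen int32 ids to int64. A simplified standalone sketch with std::vector standing in for DenseTensor (CopyIds is illustrative only):

#include <cstdint>
#include <cstring>
#include <type_traits>
#include <vector>

template <typename InT, typename OutT>
std::vector<OutT> CopyIds(const std::vector<InT>& ids) {
  std::vector<OutT> ret(ids.size());
  if (std::is_same<InT, OutT>::value) {
    // Same element type: raw copy, as in CopyIdsToVector above.
    std::memcpy(ret.data(), ids.data(), ids.size() * sizeof(InT));
  } else {
    // Different element type: convert element by element.
    for (size_t i = 0; i < ids.size(); ++i) ret[i] = static_cast<OutT>(ids[i]);
  }
  return ret;
}

int main() {
  std::vector<int32_t> ids32 = {3, 1, 2};
  auto ids64 = CopyIds<int32_t, int64_t>(ids32);  // widened, as the kernels above expect
  return ids64[0] == 3 ? 0 : 1;
}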
......@@ -359,8 +359,8 @@ void BatchNormGradRawKernel(const Context &ctx,
}
if (d_scale && d_bias) {
d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
ctx.template Alloc<BatchNormParamType<T>>(d_scale);
ctx.template Alloc<BatchNormParamType<T>>(d_bias);
}
PADDLE_ENFORCE_EQ(
......@@ -569,8 +569,8 @@ void BatchNormGradRawKernel(const Context &ctx,
/*activationDesc=*/nullptr,
/*sizeInBytes=*/&workspace_size));
workspace_ptr = workspace_tensor.mutable_data(
ctx.GetPlace(), transformed_x.type(), workspace_size);
workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
workspace_ptr = ctx.template Alloc<T>(&workspace_tensor);
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnBatchNormalizationBackwardEx(
......@@ -594,12 +594,9 @@ void BatchNormGradRawKernel(const Context &ctx,
/*dBnScaleBiasDesc=*/bn_param_desc_,
/*bnScaleData=*/scale.template data<BatchNormParamType<T>>(),
/*bnBiasData=*/nullptr,
/*dBnScaleData=*/d_scale
->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
/*dBnBiasData=*/d_bias
->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
/*dBnScaleData=*/ctx.template Alloc<BatchNormParamType<T>>(
d_scale),
/*dBnBiasData=*/ctx.template Alloc<BatchNormParamType<T>>(d_bias),
/*epsilon=*/epsilon,
/*savedMean=*/saved_mean_data,
/*savedInvVariance=*/saved_var_data,
......@@ -626,10 +623,8 @@ void BatchNormGradRawKernel(const Context &ctx,
H * W * D,
epsilon,
transformed_d_x.template data<T>(),
d_scale->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()));
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias));
} else {
BNBackward<T,
block,
......@@ -644,10 +639,8 @@ void BatchNormGradRawKernel(const Context &ctx,
H * W * D,
epsilon,
transformed_d_x.template data<T>(),
d_scale->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()));
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias));
}
// TODO(wangran16): wait for MIOpen to improve the performance of BN
......@@ -682,10 +675,8 @@ void BatchNormGradRawKernel(const Context &ctx,
ctx.template Alloc<T>(&transformed_d_x),
bn_param_desc_,
scale.template data<BatchNormParamType<T>>(),
d_scale->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias),
epsilon,
saved_mean_data,
saved_var_data));
......
......@@ -439,11 +439,11 @@ void BatchNormKernel(const Context &ctx,
// Run training mode.
// obtain running mean and running inv var, and there is no need
// to initialize them.
mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
ctx.template Alloc<BatchNormParamType<T>>(mean_out);
ctx.template Alloc<BatchNormParamType<T>>(variance_out);
saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
ctx.template Alloc<BatchNormParamType<T>>(saved_mean);
ctx.template Alloc<BatchNormParamType<T>>(saved_variance);
if ((N * H * W * D) == 1) {
// Only 1 element in normalization dimension,
......@@ -497,10 +497,10 @@ void BatchNormKernel(const Context &ctx,
/*xDesc=*/data_desc_,
/*sizeInBytes=*/&reserve_space_size));
reserve_space_ptr = reserve_space->mutable_data(
ctx.GetPlace(), transformed_x.type(), reserve_space_size);
workspace_ptr = workspace_tensor.mutable_data(
ctx.GetPlace(), transformed_x.type(), workspace_size);
reserve_space->Resize({static_cast<int64_t>(reserve_space_size)});
reserve_space_ptr = ctx.template Alloc<T>(reserve_space);
workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
workspace_ptr = ctx.template Alloc<T>(&workspace_tensor);
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
handle,
......@@ -518,15 +518,11 @@ void BatchNormKernel(const Context &ctx,
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
this_factor,
mean_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
variance_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
ctx.template Alloc<BatchNormParamType<T>>(mean_out),
ctx.template Alloc<BatchNormParamType<T>>(variance_out),
epsilon,
saved_mean->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
saved_variance->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
ctx.template Alloc<BatchNormParamType<T>>(saved_mean),
ctx.template Alloc<BatchNormParamType<T>>(saved_variance),
nullptr,
workspace_ptr,
workspace_size,
......@@ -621,15 +617,11 @@ void BatchNormKernel(const Context &ctx,
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
this_factor,
mean_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
variance_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
ctx.template Alloc<BatchNormParamType<T>>(mean_out),
ctx.template Alloc<BatchNormParamType<T>>(variance_out),
epsilon,
saved_mean->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
saved_variance->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())));
ctx.template Alloc<BatchNormParamType<T>>(saved_mean),
ctx.template Alloc<BatchNormParamType<T>>(saved_variance)));
#endif
}
}
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace phi {
template <typename InT, typename OutT>
__global__ void InputTypeConvert(const InT* in_ids,
const int64_t K,
OutT* out_ids) {
for (int i = 0; i < K; i++) {
out_ids[i] = static_cast<OutT>(in_ids[i]);
}
}
template <typename T, typename IdT>
__global__ void EmbeddingGrad(T* table,
const T* output,
const IdT* ids,
const int64_t N,
const int64_t K,
const int64_t D) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * gridDim.x;
while (idy < K) {
auto id = static_cast<int64_t>(ids[idy]);
const T* out = output + idy * D;
T* tab = table + id * D;
#ifdef PADDLE_WITH_CUDA
paddle::platform::VectorizedAtomicAddPerBlock(D, idx, blockDim.x, out, tab);
#else
for (int i = idx; i < D; i += blockDim.x) {
paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
}
#endif
idy += blockDim.y * gridDim.x;
}
}
template <typename T, typename Context>
struct EmbeddingGradCUDAFunctor {
EmbeddingGradCUDAFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
padding_idx_(padding_idx),
weight_grad_(weight_grad) {}
template <typename IdT>
void apply() {
// Since paddings are not trainable and are fixed in the forward pass, the
// gradient of paddings makes no sense and is not handled in the backward pass.
{
auto d_output_t = out_grad_;
auto d_table_t = weight_grad_;
int N = weight_grad_->dims()[0];
int D = weight_grad_->dims()[1];
int K = input_.numel();
const T* d_output = d_output_t.template data<T>();
const auto* ids = input_.template data<IdT>();
T* d_table = dev_ctx_.template Alloc<T>(d_table_t);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
hipMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
#endif
const int gridx = 2 * dev_ctx_.GetSMCount();
dim3 threads(128, 8);
dim3 grids(gridx, 1);
EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
d_table, d_output, ids, N, K, D);
}
}
private:
const phi::GPUContext& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
const DenseTensor& out_grad_;
int64_t padding_idx_;
DenseTensor* weight_grad_;
};
template <typename T, typename Context>
void EmbeddingGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad) {
EmbeddingGradCUDAFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
}
}
template <typename T, typename Context>
struct EmbeddingSparseGradCUDAFunctor {
EmbeddingSparseGradCUDAFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
padding_idx_(padding_idx),
weight_grad_(weight_grad) {}
template <typename IdT>
void apply() {
// Since paddings are not trainable and are fixed in the forward pass, the
// gradient of paddings makes no sense and is not handled in the backward pass.
const auto* ids_data = input_.template data<IdT>();
auto* d_table = weight_grad_;
auto* table = &weight_;
auto* d_output = &out_grad_;
int64_t ids_num = input_.numel();
dim3 threads(128, 8);
dim3 grids(8, 1);
auto stream = dev_ctx_.stream();
paddle::framework::Vector<int64_t> new_rows;
new_rows.resize(ids_num);
auto gpu_place = dev_ctx_.GetPlace();
paddle::framework::MixVector<int64_t> mixv_new_rows(&new_rows);
if (!std::is_same<IdT, int64_t>::value) {
InputTypeConvert<<<grids, threads, 0, stream>>>(
ids_data, ids_num, mixv_new_rows.MutableData(gpu_place));
} else {
paddle::memory::Copy(gpu_place,
mixv_new_rows.CUDAMutableData(gpu_place),
gpu_place,
ids_data,
ids_num * sizeof(int64_t),
stream);
}
mixv_new_rows.CopyToCPU();
d_table->set_rows(new_rows);
auto* d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table->dims()[1]});
dev_ctx_.template Alloc<T>(d_table_value);
auto* d_table_data = d_table_value->template data<T>();
auto* d_output_data = d_output->template data<T>();
auto d_output_dims = d_output->dims();
auto d_output_dims_2d =
phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
PADDLE_ENFORCE_EQ(d_table_value->dims(),
d_output_dims_2d,
phi::errors::InvalidArgument(
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s].",
d_table_value->dims(),
d_output_dims_2d));
paddle::memory::Copy(gpu_place,
d_table_data,
gpu_place,
d_output_data,
d_output->numel() * sizeof(T),
stream);
}
private:
const phi::GPUContext& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
const DenseTensor& out_grad_;
int64_t padding_idx_;
SelectedRows* weight_grad_;
};
template <typename T, typename Context>
void EmbeddingSparseGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad) {
EmbeddingSparseGradCUDAFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
}
}
} // namespace phi
PD_REGISTER_KERNEL(embedding_grad,
GPU,
ALL_LAYOUT,
phi::EmbeddingGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(embedding_sparse_grad,
GPU,
ALL_LAYOUT,
phi::EmbeddingSparseGradKernel,
float,
double,
phi::dtype::float16) {}
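For intuition, the dense EmbeddingGrad kernel above zeroes the weight-gradient table and then scatter-adds each row of the output gradient into the row selected by its id. Below is a minimal NumPy sketch of the same computation; the helper name and shapes are hypothetical and only illustrate the behavior, they are not part of the change:
import numpy as np

def embedding_grad_reference(ids, d_output, num_rows):
    # d_output: [K, D]; accumulate its rows into a zero-initialized table of shape [num_rows, D].
    d_table = np.zeros((num_rows, d_output.shape[1]), dtype=d_output.dtype)
    np.add.at(d_table, ids, d_output)  # repeated ids accumulate, mirroring the CUDA atomic-add path
    return d_table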
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T, typename IdT, bool PaddingFlag>
__global__ void EmbeddingFW(T *output,
const T *table,
const IdT *ids,
const int64_t N,
const int64_t K,
const int64_t D,
const int64_t padding_idx) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * gridDim.x;
while (idy < K) {
auto id = static_cast<int64_t>(ids[idy]);
T *out = output + idy * D;
const T *tab = table + id * D;
for (int i = idx; i < D; i += blockDim.x) {
if (PaddingFlag) {
if (id == padding_idx)
out[i] = static_cast<T>(0);
else
out[i] = tab[i];
} else {
out[i] = tab[i];
}
}
idy += blockDim.y * gridDim.x;
}
}
template <typename T, typename Context>
struct EmbeddingCUDAFunctor {
EmbeddingCUDAFunctor(const Context &dev_ctx,
const DenseTensor &input,
const DenseTensor &weight,
int64_t padding_idx,
DenseTensor *out)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_(out),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
size_t N = weight_.dims()[0];
size_t D = weight_.dims()[1];
size_t K = input_.numel();
const int gridx = 2 * dev_ctx_.GetSMCount();
dim3 threads(256, 4);
dim3 grids(gridx, 1);
const T *table = weight_.template data<T>();
const IdT *ids = input_.template data<IdT>();
auto *output = dev_ctx_.template Alloc<T>(out_);
auto stream = dev_ctx_.stream();
if (padding_idx_ == -1) {
EmbeddingFW<T, IdT, false><<<grids, threads, 0, stream>>>(
output, table, ids, N, K, D, padding_idx_);
} else {
EmbeddingFW<T, IdT, true><<<grids, threads, 0, stream>>>(
output, table, ids, N, K, D, padding_idx_);
}
}
private:
const phi::GPUContext &dev_ctx_;
const DenseTensor &input_;
const DenseTensor &weight_;
DenseTensor *out_;
int64_t padding_idx_;
};
template <typename T, typename Context>
void EmbeddingKernel(const Context &ctx,
const DenseTensor &input,
const DenseTensor &weight,
int64_t padding_idx,
DenseTensor *out) {
EmbeddingCUDAFunctor<T, Context> functor(
ctx, input, weight, padding_idx, out);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int32_t>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
}
}
} // namespace phi
PD_REGISTER_KERNEL(embedding,
GPU,
ALL_LAYOUT,
phi::EmbeddingKernel,
float,
double,
phi::dtype::float16) {}
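Likewise, EmbeddingFW is a row gather with optional zeroing at padding_idx. A rough NumPy sketch of the expected result (hypothetical helper, for illustration only, not the actual API):
import numpy as np

def embedding_forward_reference(ids, table, padding_idx=-1):
    out = table[ids]                  # gather rows: shape [K, D]
    if padding_idx != -1:
        out[ids == padding_idx] = 0   # rows looked up at padding_idx are zeroed
    return out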
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reverse_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/reverse_kernel_impl.h"
PD_REGISTER_KERNEL(reverse,
GPU,
ALL_LAYOUT,
phi::ReverseKernel,
int,
uint8_t,
int64_t,
bool,
float,
double) {}
......@@ -71,15 +71,15 @@ void ConvCudnnGradGradKernel(
auto dW = filter_grad;
auto dX = input_grad;
if (ddO) {
ddO->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(ddO);
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(ctx, ddO, static_cast<T>(0));
}
if (dW) {
dW->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(dW);
}
if (dX) {
dX->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(dX);
}
// const T* x = X->data<T>();
......@@ -131,7 +131,7 @@ void ConvCudnnGradGradKernel(
}
if (dX) {
ResizeToChannelFirst<Context, T>(ctx, dX, &transformed_dX_channel);
transformed_dX_channel.mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(&transformed_dX_channel);
}
} else {
......@@ -186,13 +186,13 @@ void ConvCudnnGradGradKernel(
transformed_ddX.Resize(new_input_shape);
transformed_dX.Resize(new_input_shape);
transformed_X.mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(&transformed_X);
if (ddX) {
transformed_ddX.mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(&transformed_ddX);
}
if (dX) {
transformed_dX.mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(&transformed_dX);
}
// pad for input
......
......@@ -58,10 +58,10 @@ void ConvCudnnGradKernel(const Context& ctx,
DenseTensor* input_grad,
DenseTensor* filter_grad) {
if (input_grad) {
input_grad->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(input_grad);
}
if (filter_grad) {
filter_grad->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(filter_grad);
}
std::vector<int> dilations = dilations_t;
......@@ -204,12 +204,12 @@ void ConvCudnnGradKernel(const Context& ctx,
}
DDim new_input_shape(make_ddim(new_input_shape_vec));
transformed_input.Resize(new_input_shape);
transformed_input.mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(&transformed_input);
transformed_input_grad.Resize(new_input_shape);
if (input_grad) {
transformed_input_grad.mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(&transformed_input_grad);
}
// pad for input
const int rank = transformed_input_channel.dims().size();
......@@ -427,7 +427,7 @@ void ConvCudnnGradKernel(const Context& ctx,
if (use_addto) {
DenseTensor temp_tensor(transformed_input_grad.type());
temp_tensor.Resize(transformed_input_grad.dims());
T* temp_tensor_data = temp_tensor.mutable_data<T>(ctx.GetPlace());
T* temp_tensor_data = ctx.template Alloc<T>(&temp_tensor);
workspace_handle.RunFunc(
[&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
......@@ -513,7 +513,7 @@ void ConvCudnnGradKernel(const Context& ctx,
axes[i] = i;
}
transformed_input_grad_channel.mutable_data(ctx.GetPlace());
ctx.template Alloc<T>(&transformed_input_grad_channel);
if (transformed_input_channel.dims().size() == 4) {
paddle::operators::RemovePaddingSlice<Context, T, 4>(
ctx,
......
......@@ -54,7 +54,7 @@ void ConvCudnnKernel(const Context& ctx,
int workspace_size_MB,
bool exhaustive_search_t,
DenseTensor* output) {
output->mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(output);
std::vector<int> paddings = paddings_t;
std::vector<int> dilations = dilations_t;
......@@ -170,7 +170,7 @@ void ConvCudnnKernel(const Context& ctx,
}
DDim new_input_shape(make_ddim(new_input_shape_vec));
transformed_input.Resize(new_input_shape);
transformed_input.mutable_data<T>(ctx.GetPlace());
ctx.template Alloc<T>(&transformed_input);
const int rank = transformed_input_channel.dims().size();
T pad_value(0.0);
......
......@@ -129,7 +129,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
DenseTensor col_matrix;
if (is_expand) {
col.Resize(col_shape);
col.mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(&col);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
......@@ -143,7 +143,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
if (dX && ddW_in) {
Tensor ddW;
ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
dX->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(dX);
DenseTensor transformed_dX(dX->type());
......@@ -201,7 +201,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
// oH, oW)
// dw convolution double grad: im2col(vol2col) + gemm
if (dW && ddX) {
dW->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(dW);
set_zero(dev_ctx, dW, static_cast<T>(0));
DenseTensor dW_arr = *dW;
dW_arr.Resize(filter_matrix_shape);
......@@ -244,7 +244,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
// w/ddw(Cout, Cin, kh, kw)
// ddy convolution double grad: im2col(vol2col) + gemm
if (ddY) {
ddY->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(ddY);
DenseTensor transformed_ddY(ddY->type());
if (channel_last) {
......
......@@ -128,7 +128,7 @@ void ConvGradKernel(const Context& dev_ctx,
DenseTensor col_matrix;
if (is_expand) {
col.Resize(col_shape);
col.mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(&col);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
......@@ -137,7 +137,7 @@ void ConvGradKernel(const Context& dev_ctx,
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
if (input_grad) {
input_grad->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(input_grad);
DenseTensor transformed_input_grad(input_grad->type());
if (channel_last) {
ResizeToChannelFirst<Context, T>(
......@@ -203,7 +203,7 @@ void ConvGradKernel(const Context& dev_ctx,
}
if (filter_grad) {
filter_grad->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(filter_grad);
Tensor filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape);
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
......
......@@ -44,7 +44,7 @@ void ConvKernel(const Context& dev_ctx,
// The filter will be reshaped in the calculations,
// so an assignment operation is used here,
// which avoids modifying the variable in the Scope.
output->mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(output);
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
......@@ -115,7 +115,7 @@ void ConvKernel(const Context& dev_ctx,
if (is_expand) {
// col = context.AllocateTmpTensor<T, DeviceContext>(col_shape, dev_ctx);
col.Resize(col_shape);
col.mutable_data<T>(dev_ctx.GetPlace());
dev_ctx.template Alloc<T>(&col);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/reverse_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename Context, typename T, int Rank>
struct ReverseFunctor {
void operator()(const Context& dev_ctx,
const DenseTensor& in,
DenseTensor* out,
const std::vector<int>& axis) {
Eigen::DSizes<bool, Rank> reverse_axis;
for (int i = 0; i < Rank; ++i) {
reverse_axis[i] = false;
}
for (int a : axis) {
if (a >= 0) {
reverse_axis[a] = true;
} else {
reverse_axis[Rank + a] = true;
}
}
auto in_eigen = EigenTensor<T, Rank>::From(in);
auto out_eigen = EigenTensor<T, Rank>::From(*out);
auto& dev = *dev_ctx.eigen_device();
funcs::EigenReverse<std::decay_t<decltype(dev)>, T, Rank>::Eval(
dev, out_eigen, in_eigen, reverse_axis);
}
};
template <typename T, typename Context>
void ReverseKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
int rank = x.dims().size();
switch (rank) {
case 1:
ReverseFunctor<Context, T, 1> functor1;
functor1(dev_ctx, x, out, axis);
break;
case 2:
ReverseFunctor<Context, T, 2> functor2;
functor2(dev_ctx, x, out, axis);
break;
case 3:
ReverseFunctor<Context, T, 3> functor3;
functor3(dev_ctx, x, out, axis);
break;
case 4:
ReverseFunctor<Context, T, 4> functor4;
functor4(dev_ctx, x, out, axis);
break;
case 5:
ReverseFunctor<Context, T, 5> functor5;
functor5(dev_ctx, x, out, axis);
break;
case 6:
ReverseFunctor<Context, T, 6> functor6;
functor6(dev_ctx, x, out, axis);
break;
default:
PADDLE_THROW(phi::errors::OutOfRange(
"The reserve operator does not support input tensors"
"whose ranks are greater than 6."));
}
}
} // namespace phi
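The ReverseFunctor above flags each requested axis (a negative axis a is treated as Rank + a) and reverses the tensor along the flagged axes. numpy.flip follows the same convention, so a quick sanity check looks like this (illustrative only):
import numpy as np

x = np.arange(6).reshape(2, 3)
print(np.flip(x, axis=-1))   # [[2 1 0] [5 4 3]]
print(np.flip(x, axis=1))    # identical: axis -1 maps to Rank + (-1) == 1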
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reverse_kernel.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
namespace phi {
template <typename T, typename Context>
void ReverseArrayKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<int>& axis,
std::vector<DenseTensor*> out) {
PADDLE_ENFORCE_EQ(
x.size(),
out.size(),
phi::errors::InvalidArgument("The input size(%d) and output size(%d) of "
"ReverseArrayKernel is different.",
x.size(),
out.size()));
for (size_t offset = 0; offset < x.size(); ++offset) {
auto* x_tensor = x.at(offset);
PADDLE_ENFORCE_GT(
x_tensor->memory_size(),
0,
phi::errors::PreconditionNotMet(
"The input LoDTensorArray X[%d] holds no memory.", offset));
auto out_offset = x.size() - offset - 1;
auto* out_tensor = out.at(out_offset);
out_tensor->set_lod(x_tensor->lod());
phi::Copy<Context>(
dev_ctx, *x_tensor, dev_ctx.GetPlace(), false, out_tensor);
}
}
} // namespace phi
PD_REGISTER_KERNEL(reverse_array,
CPU,
ALL_LAYOUT,
phi::ReverseArrayKernel,
int,
uint8_t,
int64_t,
bool,
float,
double) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(reverse_array,
GPU,
ALL_LAYOUT,
phi::ReverseArrayKernel,
int,
uint8_t,
int64_t,
bool,
float,
double) {}
#endif
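ReverseArrayKernel simply copies the tensors of the input list into the output in reverse order (out[i] receives a copy of x[len - 1 - i]) while preserving each tensor's LoD. In Python terms it behaves roughly like this sketch (an assumption-level reference, not the actual implementation):
def reverse_array_reference(xs):
    # out[i] is a copy of xs[len(xs) - 1 - i], matching the kernel's out_offset logic
    return [x.copy() for x in reversed(xs)]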
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ReverseKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out);
template <typename T, typename Context>
void ReverseArrayKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<int>& axis,
std::vector<DenseTensor*> out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
template <typename T, typename Context>
void SparseWeightEmbeddingGradKernel(const Context& ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad);
template <typename T, typename Context>
void SparseWeightEmbeddingSparseGradKernel(const Context& ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
template <typename T, typename Context>
void SparseWeightEmbeddingKernel(const Context& ctx,
const DenseTensor& inputx,
const SelectedRows& weight,
int64_t padding_idx,
DenseTensor* out);
} // namespace phi
......@@ -17,6 +17,18 @@
namespace phi {
KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (!ctx.HasAttr("use_addto") || !ctx.HasAttr("workspace_size_MB") ||
!ctx.HasAttr("exhaustive_search")) {
return KernelSignature("conv2d_infer",
{"Input", "Filter"},
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"Output"});
} else {
return KernelSignature("conv2d",
{"Input", "Filter"},
{"strides",
......@@ -29,6 +41,7 @@ KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) {
"workspace_size_MB",
"exhaustive_search"},
{"Output"});
}
}
KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("W")) {
return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"});
} else {
return KernelSignature(
"sparse_weight_embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"});
}
}
KernelSignature EmbeddingGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("W")) {
if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) {
return KernelSignature("embedding_sparse_grad",
{"Ids", "W", GradVarName("Out")},
{"padding_idx"},
{GradVarName("W")});
} else {
return KernelSignature("embedding_grad",
{"Ids", "W", GradVarName("Out")},
{"padding_idx"},
{GradVarName("W")});
}
} else {
if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) {
return KernelSignature("sparse_weight_embedding_sparse_grad",
{"Ids", "W", GradVarName("Out")},
{"padding_idx"},
{GradVarName("W")});
} else {
return KernelSignature("sparse_weight_embedding_grad",
{"Ids", "W", GradVarName("Out")},
{"padding_idx"},
{GradVarName("W")});
}
}
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2, embedding);
PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2_grad, embedding_grad);
PD_REGISTER_ARG_MAPPING_FN(lookup_table_v2, phi::EmbeddingOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(lookup_table_v2_grad,
phi::EmbeddingGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature ReverseOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorVectorInput("X")) {
return KernelSignature("reverse_array", {"X"}, {"axis"}, {"Out"});
} else {
return KernelSignature("reverse", {"X"}, {"axis"}, {"Out"});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(reverse, phi::ReverseOpArgumentMapping);
......@@ -114,6 +114,7 @@ function create_fake_models() {
python3 -m pip install *whl
cd ${PADDLE_ROOT}/build
python3 ${PADDLE_ROOT}/tools/infrt/fake_models/multi_fc.py
python3 ${PADDLE_ROOT}/paddle/infrt/tests/model/linear.py
}
function test_infrt() {
......
......@@ -269,7 +269,6 @@ class Tracer(core.Tracer):
if framework._in_eager_mode():
# inputs : {"sum": [tensor], ...}
# outputs : {"sum": [tensor], ...}
if type in final_state_name_mapping.keys():
final_state_type = final_state_name_mapping[type][
"final_op_name"]
......
......@@ -8730,8 +8730,8 @@ def scatter_nd_add(ref, index, updates, name=None):
"""
if in_dygraph_mode():
if _in_eager_mode():
return _C_ops.final_state_scatter_nd_add(ref, index, updates)
#if _in_eager_mode():
#return _C_ops.final_state_scatter_nd_add(ref, index, updates)
op = getattr(_C_ops, 'scatter_nd_add')
return op(ref, index, updates)
......
......@@ -698,7 +698,10 @@ class OpTest(unittest.TestCase):
+ str(np_dyg) + "\n" + "But Got" + str(np_api) + " in class " +
self.__class__.__name__)
def _calc_python_api_output(self, place):
def _calc_python_api_output(self, place, egr_inps=None, egr_oups=None):
""" set egr_inps and egr_oups = None if you want to create it by yourself.
"""
def prepare_python_api_arguments(api, op_proto_ins, op_proto_attrs,
kernel_sig):
""" map from `op proto inputs and attrs` to `api input list and api attrs dict`
......@@ -753,10 +756,15 @@ class OpTest(unittest.TestCase):
def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
if not isinstance(ret_tuple, (tuple, list)):
ret_tuple = [ret_tuple]
assert len(output_sig) == len(
ret_tuple), "expect %d outputs, but get %d outputs" % (
len(output_sig), len(ret_tuple))
return {a: b for a, b in zip(output_sig, ret_tuple)}
if len(output_sig) == len(ret_tuple):
# [assumption]: we assume {"Out": [Tensor]}
return {a: [b] for a, b in zip(output_sig, ret_tuple)}
else:
# [assumption]: multiple Tensors are returned in a single output, such as paddle.split()
assert len(
output_sig
) == 1, "Don't support multi-output with multi-tensor output."
return {output_sig[0]: ret_tuple}
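# To make the two branches above concrete (tensor names below are placeholders):
#   construct_output_dict_by_kernel_sig(out_tensor, ["Out"])   -> {"Out": [out_tensor]}
#   construct_output_dict_by_kernel_sig([t0, t1, t2], ["Out"]) -> {"Out": [t0, t1, t2]}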
def assumption_assert_and_transform(args, inp_num):
"""
......@@ -775,6 +783,18 @@ class OpTest(unittest.TestCase):
] + args[inp_num:]
return args
def _get_kernel_signature(eager_tensor_inputs, eager_tensor_outputs,
attrs_outputs):
try:
kernel_sig = _dygraph_tracer()._get_kernel_signature(
self.op_type, eager_tensor_inputs, eager_tensor_outputs,
attrs_outputs)
except RuntimeError as re:
""" we think the kernel_sig is missing.
"""
kernel_sig = None
return kernel_sig
def cal_python_api(python_api, args, kernel_sig):
inputs_sig, attrs_sig, outputs_sig = kernel_sig
args = assumption_assert_and_transform(args, len(inputs_sig))
......@@ -785,10 +805,10 @@ class OpTest(unittest.TestCase):
block = fluid.default_main_program().global_block()
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
# prepare input variable
eager_tensor_inputs = self.append_input_output_for_dygraph(
eager_tensor_inputs = egr_inps if egr_inps else self.append_input_output_for_dygraph(
op_proto, self.inputs, True, False, block)
# prepare output variable
eager_tensor_outputs = self.append_input_output_for_dygraph(
eager_tensor_outputs = egr_oups if egr_oups else self.append_input_output_for_dygraph(
op_proto, self.outputs, False, False, block)
# prepare attrbutes
......@@ -798,13 +818,13 @@ class OpTest(unittest.TestCase):
if self.attrs[attrs_name] is not None:
attrs_outputs[attrs_name] = self.attrs[attrs_name]
kernel_sig = _dygraph_tracer()._get_kernel_signature(
self.op_type, eager_tensor_inputs, eager_tensor_outputs,
attrs_outputs)
kernel_sig = _get_kernel_signature(
eager_tensor_inputs, eager_tensor_outputs, attrs_outputs)
if not kernel_sig:
return None
assert hasattr(
self, "python_api"
), "Please set the `self.python_api` if you want to compare python api output."
), "Detect there is KernelSignature for `%s` op, please set the `self.python_api` if you set check_eager = True" % self.op_type
args = prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, kernel_sig)
""" we directly return the cal_python_api value because the value is already tensor.
......@@ -1285,14 +1305,13 @@ class OpTest(unittest.TestCase):
place, no_check_set=no_check_set)
if check_eager:
# we only check end2end api when check_eager=True
with _test_eager_guard():
eager_dygraph_outs = self._calc_python_api_output(place)
if eager_dygraph_outs is None:
# missing KernelSignature, fall back to eager middle output.
eager_dygraph_outs = self._calc_dygraph_output(
place, no_check_set=no_check_set)
# we only check end2end api when check_eager=True
if hasattr(self, "python_api"):
api_outs = self._calc_python_api_output(place)
self._check_api_outs_by_dygraph_outs(api_outs, dygraph_outs,
place)
outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)
......@@ -1827,7 +1846,7 @@ class OpTest(unittest.TestCase):
if check_dygraph:
dygraph_grad = self._get_dygraph_grad(
inputs_to_check, place, output_names, user_defined_grad_outputs,
no_grad_set)
no_grad_set, False)
fp32_grads = []
for grad in dygraph_grad:
if grad.dtype == np.uint16:
......@@ -1843,7 +1862,7 @@ class OpTest(unittest.TestCase):
with _test_eager_guard():
eager_dygraph_grad = self._get_dygraph_grad(
inputs_to_check, place, output_names,
user_defined_grad_outputs, no_grad_set)
user_defined_grad_outputs, no_grad_set, check_eager)
fp32_grads = []
for grad in eager_dygraph_grad:
if grad.dtype == np.uint16:
......@@ -1869,7 +1888,8 @@ class OpTest(unittest.TestCase):
place,
output_names,
user_defined_grad_outputs=None,
no_grad_set=None):
no_grad_set=None,
check_eager=False):
with fluid.dygraph.base.guard(place=place):
block = fluid.default_main_program().global_block()
......@@ -1890,6 +1910,11 @@ class OpTest(unittest.TestCase):
if self.attrs[attrs_name] is not None:
attrs_outputs[attrs_name] = self.attrs[attrs_name]
if check_eager:
outputs = self._calc_python_api_output(place, inputs, outputs)
# if outputs is None, the kernel signature is empty or some other error happened.
if not check_eager or outputs is None:
block.append_op(
type=self.op_type,
inputs=inputs,
......
......@@ -1048,7 +1048,7 @@ class TestAbs(TestActivation):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=False)
class TestCeil(TestActivation):
......
......@@ -29,6 +29,7 @@ from paddle.fluid import Program, program_guard, core
paddle.enable_static()
#cholesky_solve implement 1
def cholesky_solution(X, B, upper=True):
if upper:
A = np.triu(X)
......@@ -43,6 +44,7 @@ def cholesky_solution(X, B, upper=True):
L, B, lower=True))
#cholesky_solve implement 2
def scipy_cholesky_solution(X, B, upper=True):
if upper:
umat = np.triu(X)
......@@ -54,27 +56,29 @@ def scipy_cholesky_solution(X, B, upper=True):
return scipy.linalg.cho_solve(K, B)
def boardcast_shape(matA, matB):
#broadcast function used by cholesky_solve
def broadcast_shape(matA, matB):
shapeA = matA.shape
shapeB = matB.shape
Boardshape = []
Broadshape = []
for idx in range(len(shapeA) - 2):
if shapeA[idx] == shapeB[idx]:
Boardshape.append(shapeA[idx])
Broadshape.append(shapeA[idx])
continue
elif shapeA[idx] == 1 or shapeB[idx] == 1:
Boardshape.append(max(shapeA[idx], shapeB[idx]))
Broadshape.append(max(shapeA[idx], shapeB[idx]))
else:
raise Exception(
'shapeA and shapeB should be boardcasted, but got {} and {}'.
'shapeA and shapeB should be broadcasted, but got {} and {}'.
format(shapeA, shapeB))
bsA = Boardshape + list(shapeA[-2:])
bsB = Boardshape + list(shapeB[-2:])
bsA = Broadshape + list(shapeA[-2:])
bsB = Broadshape + list(shapeB[-2:])
return np.broadcast_to(matA, bsA), np.broadcast_to(matB, bsB)
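# A quick illustration of the helper above (hypothetical shapes): only the leading
# batch dimensions are broadcast; the trailing two matrix dimensions are kept as-is.
#   A, B = broadcast_shape(np.zeros((1, 3, 4, 4)), np.zeros((2, 3, 4, 1)))
#   A.shape == (2, 3, 4, 4) and B.shape == (2, 3, 4, 1)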
#cholesky_solve implement in batch
def scipy_cholesky_solution_batch(bumat, bB, upper=True):
bumat, bB = boardcast_shape(bumat, bB)
bumat, bB = broadcast_shape(bumat, bB)
ushape = bumat.shape
bshape = bB.shape
bumat = bumat.reshape((-1, ushape[-2], ushape[-1]))
......@@ -90,18 +94,21 @@ def scipy_cholesky_solution_batch(bumat, bB, upper=True):
return np.array(bx).reshape(bshape)
# 2D + 2D , , upper=False
# test condition: shape: 2D + 2D , upper=False
# based on OpTest class
class TestCholeskySolveOp(OpTest):
"""
case 1
"""
#test condition set
def config(self):
self.y_shape = [15, 15]
self.x_shape = [15, 5]
self.upper = False
self.dtype = np.float64
self.dtype = np.float64  # cholesky_solve currently supports only float64/float32; extend this if the Op adds more dtypes.
#get scipy result
def set_output(self):
umat = self.inputs['Y']
self.output = scipy_cholesky_solution_batch(
......@@ -125,14 +132,16 @@ class TestCholeskySolveOp(OpTest):
self.set_output()
self.outputs = {'Out': self.output}
#check Op forward result
def test_check_output(self):
self.check_output(check_eager=True)
#check Op grad
def test_check_grad_normal(self):
self.check_grad(['Y'], 'Out', max_relative_error=0.01, check_eager=True)
# 3D(broadcast) + 3D, upper=True
# test condition: 3D(broadcast) + 3D, upper=True
class TestCholeskySolveOp3(TestCholeskySolveOp):
"""
case 3
......@@ -145,11 +154,11 @@ class TestCholeskySolveOp3(TestCholeskySolveOp):
self.dtype = np.float64
#API function test
class TestCholeskySolveAPI(unittest.TestCase):
def setUp(self):
np.random.seed(2021)
self.place = [paddle.CPUPlace()]
# self.place = [paddle.CUDAPlace(0)]
self.dtype = "float64"
self.upper = True
if core.is_compiled_with_cuda():
......@@ -178,10 +187,12 @@ class TestCholeskySolveAPI(unittest.TestCase):
fetch_list=[z])
self.assertTrue(np.allclose(fetches[0], z_np))
#test in static mode
def test_static(self):
for place in self.place:
self.check_static_result(place=place)
#test in dynamic mode
def test_dygraph(self):
def run(place):
paddle.disable_static(place)
......@@ -200,7 +211,8 @@ class TestCholeskySolveAPI(unittest.TestCase):
for idx, place in enumerate(self.place):
run(place)
def test_boardcast(self):
#test input with broadcast
def test_broadcast(self):
def run(place):
paddle.disable_static()
x_np = np.random.random([1, 30, 2]).astype(self.dtype)
......@@ -219,6 +231,7 @@ class TestCholeskySolveAPI(unittest.TestCase):
run(place)
#test condition out of bounds
class TestCholeskySolveOpError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
......
......@@ -43,11 +43,11 @@ class TestDiagV2Op(OpTest):
def test_check_output(self):
paddle.enable_static()
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
paddle.enable_static()
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=False)
def init_config(self):
pass
......
......@@ -30,6 +30,7 @@ paddle.enable_static()
class TestDiagonalOp(OpTest):
def setUp(self):
self.op_type = "diagonal"
self.python_api = paddle.diagonal
self.init_config()
self.outputs = {'Out': self.target}
......
......@@ -29,6 +29,7 @@ class TestDigammaOp(OpTest):
paddle.enable_static()
self.op_type = 'digamma'
self.python_api = paddle.digamma
self.init_dtype_type()
shape = (5, 32)
data = np.random.random(shape).astype(self.dtype) + 1
......
......@@ -41,7 +41,8 @@ class TestElementwiseAddOp(OpTest):
self.outputs = {'Out': self.out}
def check_eager(self):
return (self.use_mkldnn == False and self.axis == -1)
return False
#return (self.use_mkldnn == False and self.axis == -1)
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
......
......@@ -34,10 +34,10 @@ class TestGatherNdOpWithEmptyIndex(OpTest):
}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=False)
class TestGatherNdOpWithIndex1(OpTest):
......@@ -49,10 +49,10 @@ class TestGatherNdOpWithIndex1(OpTest):
self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=False)
class TestGatherNdOpWithLowIndex(OpTest):
......@@ -69,10 +69,10 @@ class TestGatherNdOpWithLowIndex(OpTest):
self.outputs = {'Out': xnp[tuple(index.T)]} #[[14, 25, 1], [76, 22, 3]]
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=False)
class TestGatherNdOpIndex1(OpTest):
......@@ -89,10 +89,10 @@ class TestGatherNdOpIndex1(OpTest):
self.outputs = {'Out': xnp[tuple(index.T)]}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=False)
class TestGatherNdOpWithSameIndexAsX(OpTest):
......@@ -108,10 +108,10 @@ class TestGatherNdOpWithSameIndexAsX(OpTest):
self.outputs = {'Out': xnp[tuple(index.T)]} #[25, 22]
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=False)
class TestGatherNdOpWithHighRankSame(OpTest):
......@@ -128,10 +128,10 @@ class TestGatherNdOpWithHighRankSame(OpTest):
self.outputs = {'Out': xnp[tuple(index.T)]}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=False)
class TestGatherNdOpWithHighRankDiff(OpTest):
......@@ -149,10 +149,10 @@ class TestGatherNdOpWithHighRankDiff(OpTest):
self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=False)
#Test Python API
......
......@@ -17,10 +17,11 @@ from __future__ import print_function
import unittest
import paddle.fluid as fluid
import numpy as np
from paddle.fluid.framework import _test_eager_guard
class TestImperativePartitialBackward(unittest.TestCase):
def test_partitial_backward(self):
def func_partitial_backward(self):
with fluid.dygraph.guard():
x = np.random.randn(2, 4, 5).astype("float32")
x = fluid.dygraph.to_variable(x)
......@@ -49,6 +50,11 @@ class TestImperativePartitialBackward(unittest.TestCase):
linear1.clear_gradients()
linear2.clear_gradients()
def test_partitial_backward(self):
with _test_eager_guard():
self.func_partitial_backward()
self.func_partitial_backward()
if __name__ == '__main__':
unittest.main()
......@@ -40,10 +40,10 @@ class TestIndexSampleOp(OpTest):
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=False)
def config(self):
"""
......
......@@ -105,14 +105,14 @@ class TestMatMulV2Op(OpTest):
self.outputs = {'Out': result}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
if core.is_compiled_with_rocm():
self.check_grad(
['X', 'Y'], 'Out', max_relative_error=1e-2, check_eager=True)
['X', 'Y'], 'Out', max_relative_error=1e-2, check_eager=False)
else:
self.check_grad(['X', 'Y'], 'Out', check_eager=True)
self.check_grad(['X', 'Y'], 'Out', check_eager=False)
class TestMatMulOp2(TestMatMulV2Op):
......@@ -346,7 +346,7 @@ def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(
place, atol=atol, check_eager=True)
place, atol=atol, check_eager=False)
def test_check_grad(self):
place = core.CUDAPlace(0)
......@@ -355,7 +355,7 @@ def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0):
place, ['X', 'Y'],
'Out',
max_relative_error=max_relative_error,
check_eager=True)
check_eager=False)
cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
TestMatMulOpFp16Case.__name__ = cls_name
......@@ -534,7 +534,7 @@ class TestComplexMatMulOp(OpTest):
self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad_normal(self):
self.check_grad(
......@@ -542,7 +542,7 @@ class TestComplexMatMulOp(OpTest):
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out],
check_eager=True)
check_eager=False)
def test_check_grad_ingore_x(self):
self.check_grad(
......@@ -551,7 +551,7 @@ class TestComplexMatMulOp(OpTest):
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out],
check_eager=True)
check_eager=False)
def test_check_grad_ingore_y(self):
self.check_grad(
......@@ -560,7 +560,7 @@ class TestComplexMatMulOp(OpTest):
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out],
check_eager=True)
check_eager=False)
class TestComplexMatMulOpBroadcast(OpTest):
......@@ -598,7 +598,7 @@ class TestComplexMatMulOpBroadcast(OpTest):
axis=0)
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad_normal(self):
self.check_grad(
......@@ -606,7 +606,7 @@ class TestComplexMatMulOpBroadcast(OpTest):
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out],
check_eager=True)
check_eager=False)
def test_check_grad_ingore_x(self):
self.check_grad(
......@@ -615,7 +615,7 @@ class TestComplexMatMulOpBroadcast(OpTest):
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out],
check_eager=True)
check_eager=False)
def test_check_grad_ingore_y(self):
self.check_grad(
......@@ -624,7 +624,7 @@ class TestComplexMatMulOpBroadcast(OpTest):
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out],
check_eager=True)
check_eager=False)
class TestMatMulTypePromotion(TestComplexMatMulOp):
......
......@@ -77,10 +77,10 @@ class TestScatterNdAddSimpleOp(OpTest):
self.outputs = {'Out': expect_np}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out', check_eager=True)
self.check_grad(['X', 'Updates'], 'Out', check_eager=False)
class TestScatterNdAddWithEmptyIndex(OpTest):
......@@ -101,10 +101,10 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
self.outputs = {'Out': expect_np}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out', check_eager=True)
self.check_grad(['X', 'Updates'], 'Out', check_eager=False)
class TestScatterNdAddWithHighRankSame(OpTest):
......@@ -128,10 +128,10 @@ class TestScatterNdAddWithHighRankSame(OpTest):
self.outputs = {'Out': expect_np}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out', check_eager=True)
self.check_grad(['X', 'Updates'], 'Out', check_eager=False)
class TestScatterNdAddWithHighRankDiff(OpTest):
......@@ -154,10 +154,10 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
self.outputs = {'Out': expect_np}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out', check_eager=True)
self.check_grad(['X', 'Updates'], 'Out', check_eager=False)
#Test Python API
......
......@@ -37,10 +37,10 @@ class TestScatterOp(OpTest):
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(["X", "Updates"], "Out", check_eager=True)
self.check_grad(["X", "Updates"], "Out", check_eager=False)
class TestScatterOp0(OpTest):
......@@ -57,10 +57,10 @@ class TestScatterOp0(OpTest):
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(["X", "Updates"], "Out", check_eager=True)
self.check_grad(["X", "Updates"], "Out", check_eager=False)
class TestScatterOp1(OpTest):
......@@ -80,10 +80,10 @@ class TestScatterOp1(OpTest):
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(["X", "Updates"], "Out", check_eager=True)
self.check_grad(["X", "Updates"], "Out", check_eager=False)
@unittest.skipIf(not core.is_compiled_with_cuda(),
......@@ -103,13 +103,13 @@ class TestScatterOp2(OpTest):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3, check_eager=True)
self.check_output_with_place(place, atol=1e-3, check_eager=False)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_eager=True)
place, ['X', 'Updates'], 'Out', check_eager=False)
@unittest.skipIf(not core.is_compiled_with_cuda(),
......@@ -133,13 +133,13 @@ class TestScatterOp3(OpTest):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3, check_eager=True)
self.check_output_with_place(place, atol=1e-3, check_eager=False)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_eager=True)
place, ['X', 'Updates'], 'Out', check_eager=False)
class TestScatterOp4(OpTest):
......@@ -155,10 +155,10 @@ class TestScatterOp4(OpTest):
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out', check_eager=True)
self.check_grad(['X', 'Updates'], 'Out', check_eager=False)
@unittest.skipIf(not core.is_compiled_with_cuda(),
......@@ -178,13 +178,13 @@ class TestScatterOp5(OpTest):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3, check_eager=True)
self.check_output_with_place(place, atol=1e-3, check_eager=False)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_eager=True)
place, ['X', 'Updates'], 'Out', check_eager=False)
class TestScatterAPI(unittest.TestCase):
......
......@@ -20,6 +20,8 @@ import numpy as np
import paddle
import paddle.nn as nn
from paddle.fluid.framework import _test_eager_guard, _in_eager_mode
import paddle.fluid as fluid
import paddle.fluid.core as core
class SimpleNet(nn.Layer):
......@@ -445,8 +447,7 @@ class TestTensorRegisterHook(unittest.TestCase):
self.func_multiple_hooks_for_interior_var()
self.func_multiple_hooks_for_interior_var()
# TODO(wuweilong): enable this case when DoubleGrad in eager mode is ready
def test_hook_in_double_grad(self):
def func_hook_in_double_grad(self):
def double_print_hook(grad):
grad = grad * 2
print(grad)
......@@ -461,10 +462,11 @@ class TestTensorRegisterHook(unittest.TestCase):
x.register_hook(double_print_hook)
y = x * x
fluid.set_flags({'FLAGS_retain_grad_for_all_tensor': False})
# Since y = x * x, dx = 2 * x
dx = paddle.grad(
outputs=[y], inputs=[x], create_graph=True, retain_graph=True)[0]
fluid.set_flags({'FLAGS_retain_grad_for_all_tensor': True})
z = y + dx
self.assertTrue(x.grad is None)
......@@ -475,9 +477,18 @@ class TestTensorRegisterHook(unittest.TestCase):
# x.gradient() = 2 * x + 2 = 4.0
# after changed by hook: 8.0
# TODO(wuweilong): enable this case when DoubleGrad in eager mode is ready
if core._in_eager_mode():
pass
else:
z.backward()
self.assertTrue(np.array_equal(x.grad.numpy(), np.array([8.])))
def test_hook_in_double_grad(self):
with _test_eager_guard():
self.func_hook_in_double_grad()
self.func_hook_in_double_grad()
def func_remove_one_hook_multiple_times(self):
for device in self.devices:
paddle.set_device(device)
......
......@@ -29,6 +29,7 @@ paddle.enable_static()
class TestTruncOp(OpTest):
def setUp(self):
self.op_type = "trunc"
self.python_api = paddle.trunc
self.dtype = np.float64
np.random.seed(2021)
self.inputs = {'X': np.random.random((20, 20)).astype(self.dtype)}
......
......@@ -35,10 +35,10 @@ class TestWhereOp(OpTest):
self.outputs = {'Out': np.where(self.cond, self.x, self.y)}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(['X', 'Y'], 'Out', check_eager=True)
self.check_grad(['X', 'Y'], 'Out', check_eager=False)
def init_config(self):
self.x = np.random.uniform((-3), 5, 100).astype('float64')
......
......@@ -109,7 +109,7 @@ class TestYoloBoxOp(OpTest):
self.outputs = {'Boxes': boxes, 'Scores': scores}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=False)
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
......
......@@ -243,6 +243,8 @@ def add(x, y, name=None):
"""
if paddle.in_dynamic_mode():
#if _in_eager_mode():
#return _C_ops.final_state_add(x, y)
return _C_ops.elementwise_add(x, y)
return _elementwise_op(LayerHelper('elementwise_add', **locals()))
......
......@@ -126,7 +126,7 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
axis(int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is x.ndim. when axis < 0, it works the same way
as axis + R. Default is None, in which case the input `x` is flattened into a 1-D tensor and the index of the max value is returned.
keepdim(bool, optional): Keep the axis that selecting max. The defalut value is False.
keepdim(bool, optional): Whether to keep the reduced axis in the output. If True, the output has the same number of dimensions as the input x, with size one along the given axis. Otherwise, the output has one dimension fewer than x, since the axis is squeezed. Default is False.
dtype(str|np.dtype, optional): Data type of the output tensor which can
be int32, int64. The default value is 'int64', and it will
return the int64 indices.
......@@ -147,12 +147,15 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
[6,9,2,4]])
out1 = paddle.argmax(x)
print(out1) # 2
out2 = paddle.argmax(x, axis=1)
out2 = paddle.argmax(x, axis=0)
print(out2)
# [2 3 1]
# [2, 2, 0, 1]
out3 = paddle.argmax(x, axis=-1)
print(out3)
# [2 3 1]
# [2, 3, 1]
out4 = paddle.argmax(x, axis=0, keepdim=True)
print(out4)
# [[2, 2, 0, 1]]
"""
if axis is not None and not isinstance(axis, int):
raise TypeError(
......@@ -206,7 +209,7 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
axis(int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is x.ndim. when axis < 0, it works the same way
as axis + R. Default is None, in which case the input `x` is flattened into a 1-D tensor and the index of the min value is returned.
keepdim(bool, optional): Keep the axis that selecting min. The defalut value is False.
keepdim(bool, optional): Whether to keep the reduced axis in the output. If True, the output has the same number of dimensions as the input x, with size one along the given axis. Otherwise, the output has one dimension fewer than x, since the axis is squeezed. Default is False.
dtype(str): Data type of the output tensor which can
be int32, int64. The default value is 'int64', and it will
return the int64 indices.
......@@ -227,12 +230,15 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
[6,9,2,4]])
out1 = paddle.argmin(x)
print(out1) # 4
out2 = paddle.argmin(x, axis=1)
out2 = paddle.argmin(x, axis=0)
print(out2)
# [0 0 2]
# [1, 1, 1, 2]
out3 = paddle.argmin(x, axis=-1)
print(out3)
# [0 0 2]
# [0, 0, 2]
out4 = paddle.argmin(x, axis=0, keepdim=True)
print(out4)
# [[1, 1, 1, 2]]
"""
if axis is not None and not isinstance(axis, int):
raise TypeError(
......