Commit b8d52a82 authored by: phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into enable_eager_model_test

......@@ -26,14 +26,14 @@ core_ops_args_type_info = {}
yaml_types_mapping = {
'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \
'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
'Backend' : 'Backend', 'DataLayout' : 'DataLayout', 'DataType' : 'DataType', \
'int64_t[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
'Backend' : 'paddle::experimental::Backend', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
'int64_t[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
'Tensor' : 'Tensor',
'Tensor[]' : 'std::vector<Tensor>',
'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>',
'Scalar' : 'Scalar',
'ScalarArray' : 'ScalarArray'
'Scalar' : 'paddle::experimental::Scalar',
'ScalarArray' : 'paddle::experimental::ScalarArray'
}
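A minimal sketch of how the mapping above is consumed, assuming a hypothetical attribute type string taken from a yaml api entry:

# hypothetical lookup; the key and the resulting C++ type come from yaml_types_mapping above
atype = 'ScalarArray'
assert atype in yaml_types_mapping
cxx_type = yaml_types_mapping[atype]  # 'paddle::experimental::ScalarArray'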
......@@ -208,39 +208,26 @@ def ParseYamlArgs(string):
def ParseYamlReturns(string):
# Example: Tensor, Tensor
# list = [ ["", ret_type, orig_position], ...]
returns_list = []
returns = [x.strip() for x in string.strip().split(",")]
for i in range(len(returns)):
ret_type = returns[i]
assert ret_type in yaml_types_mapping.keys()
ret_type = yaml_types_mapping[ret_type]
returns_list.append(["", ret_type, i])
return returns_list
def ParseYamlReturnsWithName(string):
# Example: Tensor(out), Tensor(out1)
# Example0: Tensor(out), Tensor(out1)
# Example1: Tensor, Tensor
# Example2: Tensor[](out), Tensor
# list = [ [ret_name, ret_type, orig_position], ...]
returns_list = []
returns = [x.strip() for x in string.strip().split(",")]
atype = r'(.*?)'
aname = r'(.*?)'
pattern = f'{atype}\({aname}\)'
for i in range(len(returns)):
ret = returns[i]
m = re.search(pattern, ret)
ret_type = m.group(1)
ret_name = m.group(2)
ret_name = ""
if "(" in ret and ")" in ret:
# Remove trailing ')'
ret = ret[:-1]
ret_type = ret.split("(")[0].strip()
ret_name = ret.split("(")[1].strip()
else:
ret_type = ret.strip()
assert ret_type in yaml_types_mapping.keys()
ret_type = yaml_types_mapping[ret_type]
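A sketch of the per-entry parsing shown above (not the full function), assuming a single return entry as in Example0/Example2:

# "Tensor(out)" -> type "Tensor", name "out"; a bare "Tensor" keeps an empty name
ret = "Tensor(out)"
ret_name = ""
if "(" in ret and ")" in ret:
    ret = ret[:-1]                       # remove trailing ')'
    ret_type = ret.split("(")[0].strip()
    ret_name = ret.split("(")[1].strip()
else:
    ret_type = ret.strip()
ret_type = yaml_types_mapping[ret_type]  # e.g. 'Tensor' -> 'Tensor'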
......@@ -266,7 +253,7 @@ def ParseYamlForwardFromBackward(string):
function_returns = m.group(3)
forward_inputs_list, forward_attrs_list = ParseYamlArgs(function_args)
forward_returns_list = ParseYamlReturnsWithName(function_returns)
forward_returns_list = ParseYamlReturns(function_returns)
return forward_inputs_list, forward_attrs_list, forward_returns_list
......@@ -296,7 +283,7 @@ def ParseYamlBackward(args_str, returns_str):
args_str = re.search(args_pattern, args_str).group(1)
inputs_list, attrs_list = ParseYamlArgs(args_str)
returns_list = ParseYamlReturnsWithName(returns_str)
returns_list = ParseYamlReturns(returns_str)
return inputs_list, attrs_list, returns_list
......
......@@ -16,20 +16,26 @@ import os
import argparse
from eager_gen import yaml_types_mapping, ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
skipped_fwd_api_names = set(["scale"])
atype_to_parsing_function = {
"bool": "CastPyArg2Boolean",
"int": "CastPyArg2Int",
"long": "CastPyArg2Long",
"int64_t": "CastPyArg2Long",
"float": "CastPyArg2Float",
"string": "CastPyArg2String",
"bool[]": "CastPyArg2Booleans",
"int[]": "CastPyArg2Ints",
"long[]": "CastPyArg2Longs",
"float[]": "CastPyArg2Floats",
"double[]": "CastPyArg2Float64s",
"string[]": "CastPyArg2Strings",
"Scalar": "CastPyArg2Scalar",
"ScalarArray": "CastPyArg2ScalarArray"
"std::vector<bool>": "CastPyArg2Booleans",
"std::vector<int>": "CastPyArg2Ints",
"std::vector<long>": "CastPyArg2Longs",
"std::vector<int64_t>": "CastPyArg2Longs",
"std::vector<float>": "CastPyArg2Floats",
"std::vector<double>": "CastPyArg2Float64s",
"std::vector<std::string>": "CastPyArg2Strings",
"paddle::experimental::Scalar": "CastPyArg2Scalar",
"paddle::experimental::ScalarArray": "CastPyArg2ScalarArray",
"paddle::experimental::Backend": "CastPyArg2Backend",
"paddle::experimental::DataType": "CastPyArg2DataType",
}
......@@ -43,15 +49,9 @@ def ParseArguments():
return args
def GetCxxType(atype):
if atype not in yaml_types_mapping.keys():
assert False
return yaml_types_mapping[atype]
def FindParsingFunctionFromAttributeType(atype):
if atype not in atype_to_parsing_function.keys():
print(f"Unable to find {atype} in atype_to_parsing_function.")
assert False
return atype_to_parsing_function[atype]
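With GetCxxType removed, attribute types reaching this helper are already C++ type strings, so the table above is keyed on them directly. A minimal usage sketch:

# hypothetical call; the key and result come from atype_to_parsing_function above
parsing_function = FindParsingFunctionFromAttributeType("std::vector<int64_t>")
# -> "CastPyArg2Longs"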
......@@ -59,7 +59,7 @@ def FindParsingFunctionFromAttributeType(atype):
def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
forward_attrs_list, forward_outputs_position_map,
optional_inputs):
optional_inputs, is_forward_only):
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
......@@ -86,11 +86,10 @@ def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
# Get Attributes
for name, atype, _, pos in forward_attrs_list:
parsing_function = FindParsingFunctionFromAttributeType(atype)
cxx_type = GetCxxType(atype)
key = f"{name}"
parse_attributes_str += f" PyObject* {name}_obj = PyTuple_GET_ITEM(args, {pos});\n"
parse_attributes_str += f" {cxx_type} {name} = {parsing_function}({name}_obj, \"{fwd_api_name}\", {pos});\n"
parse_attributes_str += f" {atype} {name} = {parsing_function}({name}_obj, \"{fwd_api_name}\", {pos});\n"
dygraph_function_call_list[pos] = f"{name}"
dygraph_function_call_str = ",".join(dygraph_function_call_list)
......@@ -127,9 +126,14 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
}}
"""
if is_forward_only:
fwd_function_name = fwd_api_name
else:
fwd_function_name = GetForwardFunctionName(fwd_api_name)
python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
fwd_api_name, fwd_api_name, get_eager_tensor_str, parse_attributes_str,
GetForwardFunctionName(fwd_api_name), dygraph_function_call_str)
fwd_function_name, dygraph_function_call_str)
python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}\n"
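For illustration only, assuming a hypothetical fwd_api_name of "matmul", the registration string above renders (wrapped here for readability) as:

# {"final_state_matmul", (PyCFunction)(void(*)(void))eager_final_state_api_matmul,
#  METH_VARARGS | METH_KEYWORDS, "C++ interface function for matmul in dygraph."}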
......@@ -213,6 +217,11 @@ def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):
#pragma once
#include "pybind11/detail/common.h"
#include "paddle/phi/api/all.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/exception.h"
......@@ -251,19 +260,23 @@ if __name__ == "__main__":
python_c_function_list = []
python_c_function_reg_list = []
for fwd_api in fwd_api_list:
# We only generate Ops with grad
is_forward_only = False
if 'backward' not in fwd_api.keys():
continue
is_forward_only = True
assert 'api' in fwd_api.keys()
assert 'args' in fwd_api.keys()
assert 'output' in fwd_api.keys()
assert 'backward' in fwd_api.keys()
fwd_api_name = fwd_api['api']
fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output']
if fwd_api_name in skipped_fwd_api_names:
continue
# Parse Dispensable Inputs
optional_inputs = []
if 'optional' in fwd_api.keys():
......@@ -285,7 +298,7 @@ if __name__ == "__main__":
python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction(
fwd_api_name, forward_inputs_position_map, forward_attrs_list,
forward_outputs_position_map, optional_inputs)
forward_outputs_position_map, optional_inputs, is_forward_only)
python_c_function_list.append(python_c_function_str)
python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str)
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
USE_OP(conv2d);
USE_OP_ITSELF(conv2d);
USE_OP(conv2d_transpose);
namespace paddle {
......
......@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace paddle {
namespace operators {
......@@ -53,12 +54,11 @@ static inline void GetNCDHW(const framework::DDim& dims,
}
template <typename DeviceContext, typename T, size_t D>
static void RemovePaddingSlice(const framework::ExecutionContext& context,
static void RemovePaddingSlice(const phi::GPUContext& context,
const Tensor* input, Tensor* out,
const std::vector<int>& starts,
const std::vector<int>& axes) {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto& place = *context.eigen_device();
auto in_dims = input->dims();
auto new_out_dims = out->dims();
auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
......@@ -171,11 +171,10 @@ void ChooseAlgo(const std::vector<PerfType>& perf_results,
using framework::ConvSearchCache;
static void SetConvMathType(const framework::ExecutionContext& ctx,
cudnnDataType_t dtype,
static void SetConvMathType(const phi::GPUContext& ctx, cudnnDataType_t dtype,
const platform::ConvolutionDescriptor& cdesc) {
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto& dev_ctx = ctx;
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
cdesc.desc(), CUDNN_TENSOR_OP_MATH));
......@@ -231,8 +230,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic,
const framework::ExecutionContext& ctx) {
bool deterministic, const phi::GPUContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
bool has_got_workspace_size = true;
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
......@@ -284,8 +282,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
} else if (deterministic) {
algo = static_cast<cudnnConvolutionFwdAlgo_t>(1);
} else {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto& dev_ctx = ctx;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
AlgorithmsCache<algo_t>& algo_cache =
......@@ -346,8 +343,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic,
const framework::ExecutionContext& ctx) {
bool deterministic, const phi::GPUContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0;
......@@ -413,8 +409,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
} else if (deterministic) {
return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} else {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto& dev_ctx = ctx;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
AlgorithmsCache<algo_t>& algo_cache =
......@@ -478,8 +473,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic,
const framework::ExecutionContext& ctx) {
bool deterministic, const phi::GPUContext& ctx) {
platform::CUDAGraphCaptureModeGuard guard;
auto dtype = platform::CudnnDataType<T>::type;
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
......@@ -534,8 +528,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
} else if (deterministic) {
return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto& dev_ctx = ctx;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
AlgorithmsCache<algo_t>& algo_cache =
*(framework::ConvSearchCache::Instance().GetBackwardFilter());
......
This diff is collapsed.
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace paddle {
namespace operators {
......@@ -51,12 +52,11 @@ static inline void GetNCDHW(const framework::DDim& dims,
}
template <typename DeviceContext, typename T, size_t D>
static void RemovePaddingSlice(const framework::ExecutionContext& context,
static void RemovePaddingSlice(const phi::GPUContext& context,
const Tensor* input, Tensor* out,
const std::vector<int>& starts,
const std::vector<int>& axes) {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto& place = *context.eigen_device();
auto in_dims = input->dims();
auto new_out_dims = out->dims();
auto offsets = Eigen::array<int, D>();
......@@ -128,11 +128,10 @@ struct SearchAlgorithm<miopenConvFwdAlgorithm_t> {
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic, size_t workspace_size,
const framework::ExecutionContext& ctx) {
const phi::GPUContext& ctx) {
algo_t algo;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto workspace_handle = ctx.cudnn_workspace_handle();
int find_count;
miopenConvAlgoPerf_t find_result;
......@@ -170,11 +169,10 @@ struct SearchAlgorithm<miopenConvBwdDataAlgorithm_t> {
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic, size_t workspace_size,
const framework::ExecutionContext& ctx) {
const phi::GPUContext& ctx) {
algo_t algo;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto workspace_handle = ctx.cudnn_workspace_handle();
int find_count;
miopenConvAlgoPerf_t find_result;
......@@ -212,11 +210,10 @@ struct SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t> {
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic, size_t workspace_size,
const framework::ExecutionContext& ctx) {
const phi::GPUContext& ctx) {
algo_t algo;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto workspace_handle = ctx.cudnn_workspace_handle();
int find_count;
miopenConvAlgoPerf_t find_result;
......
......@@ -205,14 +205,14 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
paddle::framework::DataTypeToString(input_data_type),
paddle::framework::DataTypeToString(filter_data_type)));
}
#ifndef PADDLE_WITH_ASCEND_CL
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(
library, framework::LibraryType::kCUDNN,
platform::errors::InvalidArgument(
"float16 can only be used when CUDNN or NPU is used"));
}
#endif
// #ifndef PADDLE_WITH_ASCEND_CL
// if (input_data_type == framework::proto::VarType::FP16) {
// PADDLE_ENFORCE_EQ(
// library, framework::LibraryType::kCUDNN,
// platform::errors::InvalidArgument(
// "float16 can only be used when CUDNN or NPU is used"));
// }
// #endif
#if PADDLE_WITH_CUDA
if (input_data_type == framework::proto::VarType::BF16 &&
library == framework::LibraryType::kCUDNN) {
......@@ -869,42 +869,6 @@ REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad,
ops::Conv3DDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv3d_grad_grad, ops::ConvOpDoubleGrad);
// depthwise conv kernel
// TODO(xingzhaolong): neon kernel for mobile
REGISTER_OP_CPU_KERNEL(
depthwise_conv2d,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
depthwise_conv2d_grad,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d, ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_grad,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_grad_grad,
ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d, ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_grad,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_grad_grad,
ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_VERSION(conv2d)
.AddCheckpoint(
R"ROC(
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/conv_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
depthwise_conv2d,
ops::DepthwiseConvKernel<paddle::platform::CUDADeviceContext, float>,
ops::DepthwiseConvKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
depthwise_conv2d_grad,
ops::DepthwiseConvGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::DepthwiseConvGradKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
conv2d, ops::GemmConvKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
conv2d_grad,
ops::GemmConvGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
conv3d, ops::GemmConvKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
conv3d_grad,
ops::GemmConvGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CUDADeviceContext, double>);
This diff is collapsed.
......@@ -244,10 +244,14 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_HIP
using search = SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = std::max(workspace_size, search::GetWorkspaceSize(args));
algo = search::Find<T>(args, false, deterministic, workspace_size, ctx);
algo = search::Find<T>(
args, false, deterministic, workspace_size,
ctx.template device_context<platform::CUDADeviceContext>());
#else
using search = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
algo = search::Find<T>(args, false, deterministic, ctx);
algo = search::Find<T>(
args, false, deterministic,
ctx.template device_context<platform::CUDADeviceContext>());
workspace_size =
std::max(workspace_size, search::GetWorkspaceSize(args, algo));
#endif
......@@ -501,11 +505,14 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
using search1 = SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size =
std::max(workspace_size, search1::GetWorkspaceSize(args1));
data_algo =
search1::Find<T>(args1, false, deterministic, workspace_size, ctx);
data_algo = search1::Find<T>(
args1, false, deterministic, workspace_size,
ctx.template device_context<platform::CUDADeviceContext>());
#else
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
data_algo = search1::Find<T>(args1, false, deterministic, ctx);
data_algo = search1::Find<T>(
args1, false, deterministic,
ctx.template device_context<platform::CUDADeviceContext>());
workspace_size =
std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo));
#endif
......@@ -523,11 +530,14 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
using search2 = SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2));
filter_algo =
search2::Find<T>(args2, false, deterministic, workspace_size, ctx);
filter_algo = search2::Find<T>(
args2, false, deterministic, workspace_size,
ctx.template device_context<platform::CUDADeviceContext>());
#else
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = search2::Find<T>(args2, false, deterministic, ctx);
filter_algo = search2::Find<T>(
args2, false, deterministic,
ctx.template device_context<platform::CUDADeviceContext>());
workspace_size = std::max(workspace_size,
search2::GetWorkspaceSize(args2, filter_algo));
#endif
......@@ -944,11 +954,14 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_HIP
using search1 = SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = search1::GetWorkspaceSize(args1);
bwd_algo1 =
search1::Find<T>(args1, false, deterministic, workspace_size, ctx);
bwd_algo1 = search1::Find<T>(
args1, false, deterministic, workspace_size,
ctx.template device_context<platform::CUDADeviceContext>());
#else
using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
bwd_algo1 = search1::Find<T>(args1, false, deterministic, ctx);
bwd_algo1 = search1::Find<T>(
args1, false, deterministic,
ctx.template device_context<platform::CUDADeviceContext>());
workspace_size = search1::GetWorkspaceSize(args1, bwd_algo1);
#endif
}
......@@ -965,11 +978,14 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
using search2 = SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2));
bwd_algo2 =
search2::Find<T>(args2, false, deterministic, workspace_size, ctx);
bwd_algo2 = search2::Find<T>(
args2, false, deterministic, workspace_size,
ctx.template device_context<platform::CUDADeviceContext>());
#else
using search2 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
bwd_algo2 = search2::Find<T>(args2, false, deterministic, ctx);
bwd_algo2 = search2::Find<T>(
args2, false, deterministic,
ctx.template device_context<platform::CUDADeviceContext>());
workspace_size = std::max(workspace_size,
search2::GetWorkspaceSize(args2, bwd_algo2));
#endif
......@@ -990,11 +1006,14 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
using search3 = SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size =
std::max(workspace_size, search3::GetWorkspaceSize(args3));
filter_algo =
search3::Find<T>(args3, false, deterministic, workspace_size, ctx);
filter_algo = search3::Find<T>(
args3, false, deterministic, workspace_size,
ctx.template device_context<platform::CUDADeviceContext>());
#else
using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = search3::Find<T>(args3, false, deterministic, ctx);
filter_algo = search3::Find<T>(
args3, false, deterministic,
ctx.template device_context<platform::CUDADeviceContext>());
workspace_size = std::max(workspace_size,
search3::GetWorkspaceSize(args3, filter_algo));
#endif
......@@ -1013,11 +1032,14 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
using search4 = SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size =
std::max(workspace_size, search4::GetWorkspaceSize(args4));
data_algo =
search4::Find<T>(args4, false, deterministic, workspace_size, ctx);
data_algo = search4::Find<T>(
args4, false, deterministic, workspace_size,
ctx.template device_context<platform::CUDADeviceContext>());
#else
using search4 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
data_algo = search4::Find<T>(args4, false, deterministic, ctx);
data_algo = search4::Find<T>(
args4, false, deterministic,
ctx.template device_context<platform::CUDADeviceContext>());
workspace_size =
std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
#endif
......
......@@ -13,10 +13,150 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/conv_transpose_op.h"
#include "paddle/phi/kernels/gpu/depthwise_conv.h"
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const std::string data_layout_str =
context.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const Tensor* input = context.Input<Tensor>("Input");
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
int groups = context.Attr<int>("groups");
PADDLE_ENFORCE_EQ(
groups, filter.dims()[0],
platform::errors::InvalidArgument(
"groups should be error to the 1st dimension of filter. But "
"received groups is %d and filter dimension[0] is %d",
groups, filter.dims()[0]));
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
for (auto v : dilations) {
PADDLE_ENFORCE_EQ(v, 1, platform::errors::InvalidArgument(
"dilations should be 1 in depthwise conv. "
"But received dilations is %d",
v));
}
auto in_dims = input->dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
output->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
phi::funcs::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, output, static_cast<T>(0));
math::DepthwiseConvInputGradFunctor<phi::GPUContext, T>
depthwiseConvInputGrad;
depthwiseConvInputGrad(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*output, filter, *input, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, output, data_layout);
}
};
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const std::string data_layout_str =
context.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad =
context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter"));
Tensor filter = *context.Input<Tensor>("Filter");
if (!input_grad && !filter_grad) return;
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
auto in_dims = input->dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
if (input_grad) {
math::DepthwiseConvFunctor<phi::GPUContext, T> depthwiseConv;
depthwiseConv(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*output_grad, filter, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, input_grad, data_layout);
}
if (filter_grad) {
phi::funcs::SetConstant<DeviceContext, T> set_zero;
filter_grad->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
math::DepthwiseConvFilterGradFunctor<phi::GPUContext, T>
depthwiseConvFilterGrad;
depthwiseConvFilterGrad(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*output_grad, *input, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, filter_grad, data_layout);
}
}
};
} // namespace operators
} // namespace paddle
// conv2d
REGISTER_OP_CUDA_KERNEL(conv2d_transpose,
ops::GemmConvTransposeKernel<CUDA, float>,
......
......@@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
......@@ -578,130 +577,5 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const std::string data_layout_str =
context.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const Tensor* input = context.Input<Tensor>("Input");
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
int groups = context.Attr<int>("groups");
PADDLE_ENFORCE_EQ(
groups, filter.dims()[0],
platform::errors::InvalidArgument(
"groups should be error to the 1st dimension of filter. But "
"received groups is %d and filter dimension[0] is %d",
groups, filter.dims()[0]));
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
for (auto v : dilations) {
PADDLE_ENFORCE_EQ(v, 1, platform::errors::InvalidArgument(
"dilations should be 1 in depthwise conv. "
"But received dilations is %d",
v));
}
auto in_dims = input->dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
output->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
phi::funcs::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, output, static_cast<T>(0));
math::DepthwiseConvInputGradFunctor<DeviceContext, T>
depthwiseConvInputGrad;
depthwiseConvInputGrad(
dev_ctx, *output, filter, *input, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, output, data_layout);
}
};
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const std::string data_layout_str =
context.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad =
context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter"));
Tensor filter = *context.Input<Tensor>("Filter");
if (!input_grad && !filter_grad) return;
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
auto in_dims = input->dims();
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
if (input_grad) {
math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
depthwiseConv(
dev_ctx, *output_grad, filter, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, input_grad, data_layout);
}
if (filter_grad) {
phi::funcs::SetConstant<DeviceContext, T> set_zero;
filter_grad->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
depthwiseConvFilterGrad;
depthwiseConvFilterGrad(
dev_ctx, *output_grad, *input, strides,
std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
dilations, filter_grad, data_layout);
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef __NVCC__
#include <curand_kernel.h>
#endif
#ifdef __HIPCC__
#include <hiprand_kernel.h>
#endif
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/hostdevice.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
#endif
#if !defined(_WIN32)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else
// there is no equivalent intrinsics in msvc.
#define UNLIKELY(condition) (condition)
#endif
namespace paddle {
namespace distribution {
using Tensor = framework::Tensor;
/********************* Transformation Function **********************/
template <typename T>
struct exponential_transform {
explicit exponential_transform(T lambda) : lambda_(lambda) {}
HOSTDEVICE inline T operator()(T val) const {
#if defined(__NVCC__) || defined(__HIPCC__)
if (std::is_same<T, double>::value) {
return static_cast<T>(-1.0) / lambda_ * log(val);
} else {
return static_cast<T>(-1.0) / lambda_ * __logf(val);
}
#else
return static_cast<T>(-1.0) / lambda_ * std::log(static_cast<T>(1.0) - val);
#endif
}
private:
T lambda_;
};
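The transform above is inverse-CDF sampling of the exponential distribution: for U drawn uniformly from (0, 1),

X = -\frac{1}{\lambda}\,\ln(1 - U) \sim \operatorname{Exp}(\lambda).

The host branch uses log(1 - val) because std::uniform_real_distribution draws from [0, 1), while the device branch can take log(val) directly since curand_uniform returns values in (0, 1], so the argument never hits zero.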
template <typename T>
struct uniform_transform {
explicit uniform_transform(T min, T max) : range_(max - min), min_(min) {}
HOSTDEVICE inline T operator()(T val) const {
if (UNLIKELY(val == static_cast<T>(1.0))) {
return min_;
} else {
return val * range_ + min_;
}
}
private:
T range_;
T min_;
};
template <typename T>
struct normal_transform {
explicit normal_transform(T mean, T std) : mean_(mean), std_(std) {}
HOSTDEVICE inline T operator()(T val) const { return val * std_ + mean_; }
private:
T mean_;
T std_;
};
#if defined(__NVCC__) || defined(__HIPCC__)
namespace kps = phi::kps;
/*********************** Distribution Function *************************/
template <typename T>
struct uniform_distribution;
template <typename T>
struct normal_distribution;
#if defined(__NVCC__)
template <>
struct uniform_distribution<float> {
__device__ inline float4 operator()(curandStatePhilox4_32_10_t *state) const {
return curand_uniform4(state);
}
static constexpr int kReturnsCount = 4;
};
template <>
struct uniform_distribution<double> {
__device__ inline double2 operator()(
curandStatePhilox4_32_10_t *state) const {
return curand_uniform2_double(state);
}
static constexpr int kReturnsCount = 2;
};
template <>
struct normal_distribution<float> {
__device__ inline float4 operator()(curandStatePhilox4_32_10_t *state) const {
return curand_normal4(state);
}
static constexpr int kReturnsCount = 4;
};
template <>
struct normal_distribution<double> {
__device__ inline double2 operator()(
curandStatePhilox4_32_10_t *state) const {
return curand_normal2_double(state);
}
static constexpr int kReturnsCount = 2;
};
#else
template <>
struct uniform_distribution<float> {
__device__ inline float4 operator()(
hiprandStatePhilox4_32_10_t *state) const {
return hiprand_uniform4(state);
}
static constexpr int kReturnsCount = 4;
};
template <>
struct uniform_distribution<double> {
__device__ inline double2 operator()(
hiprandStatePhilox4_32_10_t *state) const {
return hiprand_uniform2_double(state);
}
static constexpr int kReturnsCount = 2;
};
template <>
struct normal_distribution<float> {
__device__ inline float4 operator()(
hiprandStatePhilox4_32_10_t *state) const {
return hiprand_normal4(state);
}
static constexpr int kReturnsCount = 4;
};
template <>
struct normal_distribution<double> {
__device__ inline double2 operator()(
hiprandStatePhilox4_32_10_t *state) const {
return hiprand_normal2_double(state);
}
static constexpr int kReturnsCount = 2;
};
#endif
/******** Launch GPU function of distribution and transformation *********/
template <typename T, typename DistOp, typename TransformOp>
__global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset,
DistOp dist, TransformOp trans, T *out_data,
size_t stride) {
size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
static constexpr int kCount = DistOp::kReturnsCount;
#if defined(__NVCC__)
curandStatePhilox4_32_10_t state;
curand_init(seed, idx + THREAD_ID_X, offset, &state);
using SType = curandStatePhilox4_32_10_t;
#else
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed, idx + THREAD_ID_X, offset, &state);
using SType = hiprandStatePhilox4_32_10_t;
#endif
size_t total_thread = GRID_NUM_X * BLOCK_NUM_X;
T args[kCount];
T result[kCount];
for (size_t i = idx; i < size; i += total_thread * kCount) {
kps::ElementwiseRandom<SType, T, kCount, 1, DistOp>(&args[0], dist, &state);
kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(&result[0], &args[0],
trans);
kps::WriteData<T, T, kCount, 1, 1, true>(out_data + i, &result[0], size - i,
1, stride, 1);
__syncthreads();
}
}
template <typename T, typename DistOp, typename TransformOp>
void distribution_and_transform(const platform::CUDADeviceContext &dev_ctx,
Tensor *out, DistOp dist, TransformOp trans) {
T *out_data = out->mutable_data<T>(dev_ctx.GetPlace());
auto size = out->numel();
int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
size_t block_size = 256;
size_t expect_grid_size = (size + block_size - 1) / block_size;
const auto &prop = platform::GetDeviceProperties(device_id);
size_t max_grid_size = (prop.maxThreadsPerMultiProcessor / block_size) *
prop.multiProcessorCount;
size_t grid_size =
expect_grid_size > max_grid_size ? max_grid_size : expect_grid_size;
size_t total_thread = block_size * grid_size;
size_t curand4_loop_times =
(size + 4 * total_thread - 1) / (4 * total_thread);
// 'increment' should be a multiple of 4
uint64_t increment = curand4_loop_times * 4;
auto seed_offset = gen_cuda->IncrementOffset(increment);
uint64_t seed = seed_offset.first;
uint64_t offset = seed_offset.second;
DistributionKernel<
T, DistOp, TransformOp><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
size, seed, offset, dist, trans, out_data, total_thread);
}
#endif
} // namespace distribution
} // namespace paddle
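A worked example (with hypothetical sizes) of the offset bookkeeping in distribution_and_transform above: each curand_uniform4 call consumes four values per thread from the Philox sequence, which is why 'increment' is rounded to a multiple of 4 before IncrementOffset.

# illustrative arithmetic only; block/grid sizes are assumptions for the example
size = 1_000_000
block_size, grid_size = 256, 256
total_thread = block_size * grid_size                                      # 65536
curand4_loop_times = (size + 4 * total_thread - 1) // (4 * total_thread)   # ceil -> 4
increment = curand4_loop_times * 4                                         # 16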
......@@ -34,8 +34,8 @@ limitations under the License. */
#include "paddle/fluid/operators/dropout_impl_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/functors.h"
namespace paddle {
......@@ -86,8 +86,8 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
bool is_upscale_in_train,
uint64_t increment) {
using MT = typename details::MPTypeTrait<T>::Type;
using LoadT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
using LoadT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
#ifdef PADDLE_WITH_HIP
int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
......@@ -102,7 +102,7 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
LoadT src_val;
platform::Load<T, VecSize>(&src[i], &src_val);
phi::Load<T, VecSize>(&src[i], &src_val);
#ifdef PADDLE_WITH_HIP
float4 rand = hiprand_uniform4(&state);
......@@ -126,8 +126,8 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
}
}
platform::Store<T, VecSize>(dst_val, &dst[i]);
platform::Store<MaskType, VecSize>(mask_val, &mask[i]);
phi::Store<T, VecSize>(dst_val, &dst[i]);
phi::Store<MaskType, VecSize>(mask_val, &mask[i]);
}
}
......@@ -153,16 +153,16 @@ __global__ void DropoutGradCUDAKernel(
const typename details::MPTypeTrait<T>::Type factor, const int64_t size,
T* dx) {
using MT = typename details::MPTypeTrait<T>::Type;
using LoadT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
using LoadT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
LoadT dout_val;
platform::Load<T, VecSize>(&dout[i], &dout_val);
phi::Load<T, VecSize>(&dout[i], &dout_val);
MaskLoadT mask_val;
platform::Load<MaskType, VecSize>(&mask[i], &mask_val);
phi::Load<MaskType, VecSize>(&mask[i], &mask_val);
LoadT dx_val;
......@@ -172,7 +172,7 @@ __global__ void DropoutGradCUDAKernel(
static_cast<MT>(mask_val[j]) * factor);
}
platform::Store<T, VecSize>(dx_val, &dx[i]);
phi::Store<T, VecSize>(dx_val, &dx[i]);
}
}
......@@ -219,7 +219,7 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
uint64_t increment;
// VectorizedRandomGenerator use curand_uniform4, so we only support
// vec_size is 4;
int vec_size = (platform::GetVectorizedSize<T>(x_data) == 4) ? 4 : 1;
int vec_size = (phi::GetVectorizedSize<T>(x_data) == 4) ? 4 : 1;
auto gpu_config = GetGpuLaunchConfig1D(dev_ctx, x_numel, vec_size);
auto offset =
((x_numel - 1) / (gpu_config.GetThreadNum() * vec_size) + 1) * vec_size;
......
......@@ -76,7 +76,7 @@ class ExponentialKernel<platform::CPUDeviceContext, T>
auto engine = gen->GetCPUEngine();
std::uniform_real_distribution<T> uniform(0.0, 1.0);
distribution::exponential_transform<T> trans(lambda);
phi::funcs::exponential_transform<T> trans(lambda);
for (int64_t i = 0; i < size; ++i) {
out_data[i] = trans(uniform(*engine));
}
......
......@@ -26,9 +26,9 @@ class ExponentialKernel<platform::CUDADeviceContext, T>
auto& dev_cxt = ctx.template device_context<platform::CUDADeviceContext>();
T lambda = static_cast<T>(ctx.Attr<float>("lambda"));
distribution::uniform_distribution<T> dist;
distribution::exponential_transform<T> trans(lambda);
distribution::distribution_and_transform<T>(dev_cxt, out, dist, trans);
phi::funcs::uniform_distribution<T> dist;
phi::funcs::exponential_transform<T> trans(lambda);
phi::funcs::distribution_and_transform<T>(dev_cxt, out, dist, trans);
}
};
......
......@@ -17,7 +17,7 @@
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distribution_helper.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......
......@@ -89,9 +89,9 @@ __global__ void BroadcastKernelBinary(
template <typename T>
void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
const T* in0, const T* in1, T* out) {
int in_vec_size = std::min(platform::GetVectorizedSize<T>(in0),
platform::GetVectorizedSize<T>(in1));
int out_vec_size = std::min(4, platform::GetVectorizedSize<T>(out));
int in_vec_size =
std::min(phi::GetVectorizedSize<T>(in0), phi::GetVectorizedSize<T>(in1));
int out_vec_size = std::min(4, phi::GetVectorizedSize<T>(out));
int vec_size = std::min(out_vec_size, in_vec_size);
int numel = m * n;
......@@ -191,9 +191,9 @@ void SetConfigForColumnReduce(const int max_threads, const int reduce_num,
int num_block = (max_threads / left_num);
if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
*blocking_size = phi::kernels::details::GetLastPow2(reduce_num / num_block);
*blocking_size = phi::funcs::details::GetLastPow2(reduce_num / num_block);
if (*blocking_size <= 1) {
*blocking_size = phi::kernels::details::GetLastPow2(sqrt(reduce_num));
*blocking_size = phi::funcs::details::GetLastPow2(sqrt(reduce_num));
} else if (*blocking_size * 2 < reduce_num) {
*blocking_size *= 2;
}
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/fused/cudnn_norm_conv.cu.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace framework = paddle::framework;
......@@ -29,10 +30,10 @@ namespace platform = paddle::platform;
namespace op = paddle::operators;
using Tensor = paddle::framework::Tensor;
USE_OP(conv2d);
USE_OP(conv2d_grad);
USE_OP_DEVICE_KERNEL(conv2d, CUDNN);
USE_OP_DEVICE_KERNEL(conv2d_grad, CUDNN);
USE_OP_ITSELF(conv2d);
USE_OP_ITSELF(conv2d_grad);
PD_DECLARE_KERNEL(conv2d, GPUDNN, ALL_LAYOUT);
PD_DECLARE_KERNEL(conv2d_grad, GPUDNN, ALL_LAYOUT);
template <typename T>
void InitRandomTensor(const std::vector<int64_t> &dims,
......
......@@ -130,17 +130,17 @@ __global__ void FusedDropoutActGrad(Functor act_grad, const T *dout,
const T factor, const int64_t size, T *dx) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
LoadT dout_vec;
LoadT src_vec;
MaskLoadT mask_vec;
platform::Load<T, VecSize>(&dout[i], &dout_vec);
platform::Load<MaskType, VecSize>(&mask[i], &mask_vec);
platform::Load<T, VecSize>(&src[i], &src_vec);
phi::Load<T, VecSize>(&dout[i], &dout_vec);
phi::Load<MaskType, VecSize>(&mask[i], &mask_vec);
phi::Load<T, VecSize>(&src[i], &src_vec);
StoreT dx_vec;
#pragma unroll
......@@ -148,7 +148,7 @@ __global__ void FusedDropoutActGrad(Functor act_grad, const T *dout,
T tmp = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
dx_vec[ii] = tmp * act_grad.UseOut(src_vec[ii]);
}
platform::Store<T, VecSize>(dx_vec, &dx[i]);
phi::Store<T, VecSize>(dx_vec, &dx[i]);
}
}
......@@ -167,9 +167,9 @@ __global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout,
T *dx, T *dbias) {
int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x;
using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
T tmp_sum[VecSize] = {static_cast<T>(0)};
// calculate the dx and temporary sum
if (col_id * VecSize < cols) {
......@@ -180,10 +180,10 @@ __global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout,
LoadT bias_vec;
MaskLoadT mask_vec;
platform::Load<T, VecSize>(&dout[index], &dout_vec);
platform::Load<T, VecSize>(&src[index], &src_vec);
platform::Load<MaskType, VecSize>(&mask[index], &mask_vec);
platform::Load<T, VecSize>(&bias[col_id * VecSize], &bias_vec);
phi::Load<T, VecSize>(&dout[index], &dout_vec);
phi::Load<T, VecSize>(&src[index], &src_vec);
phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);
phi::Load<T, VecSize>(&bias[col_id * VecSize], &bias_vec);
StoreT dx_vec;
#pragma unroll
......@@ -194,7 +194,7 @@ __global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout,
dx_vec[i] = val;
tmp_sum[i] += val;
}
platform::Store<T, VecSize>(dx_vec, &dx[index]);
phi::Store<T, VecSize>(dx_vec, &dx[index]);
}
}
......
......@@ -21,11 +21,11 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/functors.h"
namespace paddle {
......
......@@ -42,12 +42,12 @@ __device__ void CalcLayernormY(
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *bias, const T *x,
T *y, const int row_id, const int col_id, const int cols,
const LayerNormParamType<T> mean_val, const LayerNormParamType<T> invvar) {
using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using LoadU = platform::AlignedVector<U, VecSize>;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using LoadU = phi::AlignedVector<U, VecSize>;
using LoadScaleOrBias =
platform::AlignedVector<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>,
VecSize>;
phi::AlignedVector<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>,
VecSize>;
for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) {
LoadScaleOrBias scale_vec;
LoadScaleOrBias bias_vec;
......@@ -60,15 +60,15 @@ __device__ void CalcLayernormY(
static_cast<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(0);
}
// vectorize load data from global
platform::Load<T, VecSize>(&x[row_id * cols + i], &x_vec);
phi::Load<T, VecSize>(&x[row_id * cols + i], &x_vec);
if (scale != nullptr) {
platform::Load<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>,
VecSize>(&scale[i], &scale_vec);
phi::Load<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, VecSize>(
&scale[i], &scale_vec);
}
if (bias != nullptr) {
platform::Load<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>,
VecSize>(&bias[i], &bias_vec);
phi::Load<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, VecSize>(
&bias[i], &bias_vec);
}
StoreT y_vec;
......@@ -78,7 +78,7 @@ __device__ void CalcLayernormY(
(static_cast<U>(x_vec[ii]) - mean_val) * invvar +
static_cast<U>(bias_vec[ii]));
}
platform::Store<T, VecSize>(y_vec, &y[row_id * cols + i]);
phi::Store<T, VecSize>(y_vec, &y[row_id * cols + i]);
}
}
......@@ -190,9 +190,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
const ScaleT *__restrict__ beta_ptr, MaskType *__restrict__ mask_out_ptr,
U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr,
T *__restrict__ residual_out_ptr, T *__restrict__ y_ptr) {
using Vec = platform::AlignedVector<T, VecSize>;
using Vec_scale = platform::AlignedVector<ScaleT, VecSize>;
using MaskStoreT = platform::AlignedVector<MaskType, VecSize>;
using Vec = phi::AlignedVector<T, VecSize>;
using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
......@@ -214,8 +214,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
Vec_scale beta[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
platform::Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
phi::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
phi::Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
col += THREADS_PER_ROW;
}
......@@ -225,10 +225,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
Vec residual[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize,
&x[it]);
platform::Load<T, VecSize>(
residual_ptr + row * LN_NUM_COLS + col * VecSize, &residual[it]);
phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
phi::Load<T, VecSize>(residual_ptr + row * LN_NUM_COLS + col * VecSize,
&residual[it]);
col += THREADS_PER_ROW;
}
......@@ -270,9 +269,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
// store dropout_residual_out and mask_out
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Store<T, VecSize>(
phi::Store<T, VecSize>(
x[it], residual_out_ptr + row * LN_NUM_COLS + col * VecSize);
platform::Store<MaskType, VecSize>(
phi::Store<MaskType, VecSize>(
mask_vec[it], mask_out_ptr + row * LN_NUM_COLS + col * VecSize);
col += THREADS_PER_ROW;
}
......@@ -333,8 +332,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Store<T, VecSize>(x[it],
y_ptr + row * LN_NUM_COLS + col * VecSize);
phi::Store<T, VecSize>(x[it], y_ptr + row * LN_NUM_COLS + col * VecSize);
col += THREADS_PER_ROW;
}
}
......
......@@ -32,9 +32,9 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
const T *__restrict__ bias, T *dst, MaskType *mask, const bool is_test,
typename details::MPTypeTrait<T>::Type *mean_val,
typename details::MPTypeTrait<T>::Type *var_val, Functor act_func) {
using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using MaskStoreT = platform::AlignedVector<MaskType, VecSize>;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
using U = typename details::MPTypeTrait<T>::Type;
LoadT src_vec;
......@@ -46,14 +46,13 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
residual_vec[ii] = static_cast<T>(0);
}
// vectorize load data from global
platform::Load<T, VecSize>(&src[row_id * cols + col_id], &src_vec);
phi::Load<T, VecSize>(&src[row_id * cols + col_id], &src_vec);
if (residual) {
platform::Load<T, VecSize>(&residual[row_id * cols + col_id],
&residual_vec);
phi::Load<T, VecSize>(&residual[row_id * cols + col_id], &residual_vec);
}
if (bias) {
platform::Load<T, VecSize>(&bias[col_id], &bias_vec);
phi::Load<T, VecSize>(&bias[col_id], &bias_vec);
}
MaskStoreT mask_vec;
......@@ -89,9 +88,9 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
}
// store result to global
platform::Store<T, VecSize>(dest_vec, &dst[row_id * cols + col_id]);
phi::Store<T, VecSize>(dest_vec, &dst[row_id * cols + col_id]);
if (!is_test) {
platform::Store<MaskType, VecSize>(mask_vec, &mask[row_id * cols + col_id]);
phi::Store<MaskType, VecSize>(mask_vec, &mask[row_id * cols + col_id]);
}
}
......@@ -176,21 +175,21 @@ __global__ void FusedResidualDropoutGrad(const T *dout, const MaskType *mask,
T *dx) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
LoadT dout_vec;
MaskLoadT mask_vec;
platform::Load<T, VecSize>(&dout[i], &dout_vec);
platform::Load<MaskType, VecSize>(&mask[i], &mask_vec);
phi::Load<T, VecSize>(&dout[i], &dout_vec);
phi::Load<MaskType, VecSize>(&mask[i], &mask_vec);
StoreT dx_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
dx_vec[ii] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
}
platform::Store<T, VecSize>(dx_vec, &dx[i]);
phi::Store<T, VecSize>(dx_vec, &dx[i]);
}
}
......@@ -209,9 +208,9 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout,
T *dbias) {
int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x;
using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
T tmp_sum[VecSize] = {static_cast<T>(0)};
// calculate the dx and temporary sum
......@@ -221,8 +220,8 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout,
LoadT out_vec;
MaskLoadT mask_vec;
StoreT dx_vec;
platform::Load<T, VecSize>(&dout[index], &out_vec);
platform::Load<MaskType, VecSize>(&mask[index], &mask_vec);
phi::Load<T, VecSize>(&dout[index], &out_vec);
phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
......@@ -230,7 +229,7 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout,
tmp_sum[i] += out_vec[i];
}
platform::Store<T, VecSize>(dx_vec, &dx[index]);
phi::Store<T, VecSize>(dx_vec, &dx[index]);
}
}
......
......@@ -19,9 +19,10 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/distribution_helper.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/operators/index_impl.cu.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
DECLARE_bool(use_curand);
......@@ -79,10 +80,10 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
int64_t gen_offset = size * seed_offset.second;
auto func = GaussianGenerator<T>(mean, std, seed_offset.first,
seed_offset.second);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
phi::IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
} else {
auto func = GaussianGenerator<T>(mean, std, seed);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
phi::IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
}
}
};
......
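// A hedged sketch of the functor shape consumed by phi::IndexKernel above:
// each output element is produced from its linear index. ConstantGenerator is
// illustrative only; the index parameter type mirrors the generator functors
// in this file and is assumed to be unsigned int.
template <typename T>
struct ConstantGenerator {
  T value;
  explicit ConstantGenerator(T v) : value(v) {}
  __host__ __device__ T operator()(const unsigned int idx) const {
    return value;  // a real generator (e.g. GaussianGenerator) samples here
  }
};
// Usage mirroring the call sites above (dev_cxt and tensor come from the op):
//   phi::IndexKernel<T, ConstantGenerator<T>>(dev_cxt, tensor,
//                                             ConstantGenerator<T>(T(0)));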
......@@ -58,7 +58,7 @@ static __global__ void FP16FastGeluFwdCUDAKernel(const __half* x, __half* y,
static_cast<size_t>(threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
size_t stride = static_cast<size_t>(blockDim.x * gridDim.x) * VecSize;
for (; offset < n; offset += stride) {
using ArrT = platform::AlignedVector<__half, VecSize>;
using ArrT = phi::AlignedVector<__half, VecSize>;
ArrT in_arr = *reinterpret_cast<const ArrT*>(x + offset);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
......@@ -77,7 +77,7 @@ static __global__ void FP16FastGeluBwdCUDAKernel(const __half* x,
static_cast<size_t>(threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
size_t stride = static_cast<size_t>(blockDim.x * gridDim.x) * VecSize;
for (; offset < n; offset += stride) {
using ArrT = platform::AlignedVector<__half, VecSize>;
using ArrT = phi::AlignedVector<__half, VecSize>;
ArrT x_in_arr = *reinterpret_cast<const ArrT*>(x + offset);
ArrT y_g_in_arr = *reinterpret_cast<const ArrT*>(y_g + offset);
#pragma unroll
......@@ -103,7 +103,7 @@ static bool TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(
#define PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(__vec_size, __use_fast_math) \
do { \
constexpr auto kAlignment = \
alignof(platform::AlignedVector<__half, __vec_size>); \
alignof(phi::AlignedVector<__half, __vec_size>); \
if (n % __vec_size == 0 && is_aligned(x, kAlignment) && \
is_aligned(y, kAlignment)) { \
size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
......@@ -138,7 +138,7 @@ static bool TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(
#define PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(__vec_size, __use_fast_math) \
do { \
constexpr auto kAlignment = \
alignof(platform::AlignedVector<__half, __vec_size>); \
alignof(phi::AlignedVector<__half, __vec_size>); \
if (n % __vec_size == 0 && is_aligned(x, kAlignment) && \
is_aligned(x, kAlignment) && is_aligned(y_g, kAlignment) && \
is_aligned(x_g, kAlignment)) { \
......
......@@ -19,11 +19,11 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/distribution_helper.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
namespace paddle {
......@@ -58,7 +58,7 @@ void IndexKernel(const KPDevice &dev_ctx, Tensor *out, Functor func) {
int numel = out->numel();
T *out_data = out->mutable_data<T>(dev_ctx.GetPlace());
if (numel <= 0) return;
int vec_size = paddle::platform::GetVectorizedSize(out_data);
int vec_size = phi::GetVectorizedSize(out_data);
#ifdef PADDLE_WITH_XPU_KP
int block = 64;
int grid = 8;
......
......@@ -22,10 +22,10 @@ limitations under the License. */
namespace cub = hipcub;
#endif
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace paddle {
namespace operators {
......@@ -186,8 +186,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
const ScaleT *__restrict__ gamma_ptr, const ScaleT *__restrict__ beta_ptr,
U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr,
T *__restrict__ y_ptr) {
using Vec = platform::AlignedVector<T, VecSize>;
using Vec_scale = platform::AlignedVector<ScaleT, VecSize>;
using Vec = phi::AlignedVector<T, VecSize>;
using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
......@@ -203,8 +203,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
Vec_scale beta[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
platform::Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
phi::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
phi::Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
col += THREADS_PER_ROW;
}
......@@ -213,8 +213,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
Vec x[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize,
&x[it]);
phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
col += THREADS_PER_ROW;
}
U xf[LDGS * VecSize];
......@@ -276,8 +275,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Store<T, VecSize>(x[it],
y_ptr + row * LN_NUM_COLS + col * VecSize);
phi::Store<T, VecSize>(x[it], y_ptr + row * LN_NUM_COLS + col * VecSize);
col += THREADS_PER_ROW;
}
}
......@@ -401,9 +399,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
U *__restrict__ dgamma_temp_ptr, U *__restrict__ dbeta_temp_ptr,
T *__restrict__ dx_ptr, const MaskType *mask_ptr = nullptr,
T factor = static_cast<T>(0), T *d_dropout_src_ptr = nullptr) {
using Vec = platform::AlignedVector<T, VecSize>;
using Vec_scale = platform::AlignedVector<ScaleT, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
using Vec = phi::AlignedVector<T, VecSize>;
using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
......@@ -439,7 +437,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
int col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
platform::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
phi::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
col += THREADS_PER_ROW;
}
......@@ -452,12 +450,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
int col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
platform::Load<T, VecSize>(dout_ptr + row * LN_NUM_COLS + col * VecSize,
&dout[it]);
platform::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize,
&x[it]);
phi::Load<T, VecSize>(dout_ptr + row * LN_NUM_COLS + col * VecSize,
&dout[it]);
phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
if (isFusedDropoutResidualLn) {
platform::Load<MaskType, VecSize>(
phi::Load<MaskType, VecSize>(
mask_ptr + row * LN_NUM_COLS + col * VecSize, &mask_vec[it]);
}
......@@ -552,10 +549,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
platform::Store<T, VecSize>(x[it],
dx_ptr + row * LN_NUM_COLS + col * VecSize);
phi::Store<T, VecSize>(x[it], dx_ptr + row * LN_NUM_COLS + col * VecSize);
if (isFusedDropoutResidualLn) {
platform::Store<T, VecSize>(
phi::Store<T, VecSize>(
dout[it], d_dropout_src_ptr + row * LN_NUM_COLS + col * VecSize);
}
col += THREADS_PER_ROW;
......@@ -641,7 +637,7 @@ template <
__global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
const int rows, U *__restrict__ dg_part_, U *__restrict__ db_part_,
ScaleT *__restrict__ dg_, ScaleT *__restrict__ db_) {
using Vec = platform::AlignedVector<U, VecSize>;
using Vec = phi::AlignedVector<U, VecSize>;
static_assert(VEC_COLS == LN_NUM_COLS / VecSize, "");
const int tidx = threadIdx.x;
......@@ -669,8 +665,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
for (int row = r; row < rows; row += ROWS_PER_CTA) {
Vec dg;
Vec db;
platform::Load<U, VecSize>(dg_part_ptr, &dg);
platform::Load<U, VecSize>(db_part_ptr, &db);
phi::Load<U, VecSize>(dg_part_ptr, &dg);
phi::Load<U, VecSize>(db_part_ptr, &db);
dg_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
db_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/hostdevice.h"
namespace paddle {
namespace operators {
namespace math {
using DataLayout = framework::DataLayout;
/*
* \brief Compute the depthwise convolution which include
* forward process and backpropagation process
*/
template <typename DeviceContext, typename T,
bool fuse_relu_before_conv = false>
class DepthwiseConvFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations, framework::Tensor* output,
const DataLayout data_layout = DataLayout::kNCHW);
};
template <typename DeviceContext, typename T,
bool fuse_relu_before_conv = false>
class DepthwiseConvInputGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& filter,
const framework::Tensor& output_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
framework::Tensor* input_grad,
const DataLayout data_layout = DataLayout::kNCHW);
};
template <typename DeviceContext, typename T,
bool fuse_relu_before_conv = false>
class DepthwiseConvFilterGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& output_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
framework::Tensor* filter_grad,
const DataLayout data_layout = DataLayout::kNCHW);
};
} // namespace math
} // namespace operators
} // namespace paddle
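// A hedged usage sketch of the forward functor declared above. DeviceContext
// is assumed to be a type for which Paddle provides a specialization, and the
// tensors are assumed to be already allocated on that device with compatible
// shapes; the helper name and parameter values are illustrative.
template <typename DeviceContext>
void RunDepthwiseConvForward(const DeviceContext &dev_ctx,
                             const paddle::framework::Tensor &input,
                             const paddle::framework::Tensor &filter,
                             paddle::framework::Tensor *output) {
  paddle::operators::math::DepthwiseConvFunctor<DeviceContext, float> conv;
  conv(dev_ctx, input, filter, /*strides=*/{1, 1}, /*paddings=*/{0, 0},
       /*dilations=*/{1, 1}, output,
       paddle::framework::DataLayout::kNCHW);
}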
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
namespace paddle {
namespace platform {
class CPUDeviceContext;
......@@ -141,6 +143,116 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
}
};
template <class T>
class Vol2ColFunctor<phi::CPUContext, T> {
public:
void operator()(const phi::CPUContext& context, const framework::Tensor& vol,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* col,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
platform::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.",
vol.dims().size()));
PADDLE_ENFORCE_EQ(col->dims().size(), 7,
platform::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.",
col->dims().size()));
int input_channels =
(data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
int input_depth =
(data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]);
int input_height =
(data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]);
int input_width =
(data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]);
int filter_depth = col->dims()[1];
int filter_height = col->dims()[2];
int filter_width = col->dims()[3];
int output_depth = col->dims()[4];
int output_height = col->dims()[5];
int output_width = col->dims()[6];
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
// changed
bool paddings_size_is_6 = (paddings.size() == 6);
int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1;
PADDLE_ENFORCE_EQ(
input_depth_tmp, output_depth,
platform::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp, output_depth));
auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1;
PADDLE_ENFORCE_EQ(
input_height_tmp, output_height,
platform::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp, output_height));
auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1;
PADDLE_ENFORCE_EQ(
input_width_tmp, output_width,
platform::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp, output_width));
const T* vol_data = vol.data<T>();
T* col_data = col->data<T>();
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int d_offset = (c / filter_width / filter_height) % filter_depth;
int c_in = c / filter_width / filter_height / filter_depth;
for (int d = 0; d < output_depth; ++d) {
int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0];
for (int h = 0; h < output_height; ++h) {
int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1];
for (int w = 0; w < output_width; ++w) {
int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2];
int col_idx =
((c * output_depth + d) * output_height + h) * output_width + w;
int vol_idx;
if (data_layout != DataLayout::kNHWC) {
vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
} else {
vol_idx = ((d_pad * input_height + h_pad) * input_width + w_pad) *
input_channels +
c_in;
}
col_data[col_idx] =
(h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
w_pad >= input_width || d_pad < 0 || d_pad >= input_depth)
? static_cast<T>(0)
: vol_data[vol_idx];
}
}
}
}
}
};
/*
* vol = [input_channels,input_depth, input_height, input_width]
* col =
......@@ -258,10 +370,125 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
}
};
template <class T>
class Col2VolFunctor<phi::CPUContext, T> {
public:
void operator()(const phi::CPUContext& context, const framework::Tensor& col,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* vol,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
platform::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.",
vol->dims().size()));
PADDLE_ENFORCE_EQ(col.dims().size(), 7,
platform::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.",
col.dims().size()));
int input_channels =
(data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
int input_depth =
(data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]);
int input_height =
(data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]);
int input_width =
(data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]);
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
int output_depth = col.dims()[4];
int output_height = col.dims()[5];
int output_width = col.dims()[6];
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
bool paddings_size_is_6 = (paddings.size() == 6);
int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1;
PADDLE_ENFORCE_EQ(
input_depth_tmp, output_depth,
platform::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp, output_depth));
auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1;
PADDLE_ENFORCE_EQ(
input_height_tmp, output_height,
platform::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp, output_height));
auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1;
PADDLE_ENFORCE_EQ(
input_width_tmp, output_width,
platform::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp, output_width));
T* vol_data = vol->data<T>();
const T* col_data = col.data<T>();
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int d_offset = (c / filter_width / filter_height) % filter_depth;
int cIm = c / filter_width / filter_height / filter_depth;
for (int d = 0; d < output_depth; ++d) {
int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0];
for (int h = 0; h < output_height; ++h) {
int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1];
for (int w = 0; w < output_width; ++w) {
int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2];
if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
int vol_idx;
if (data_layout != DataLayout::kNHWC) {
vol_idx = ((cIm * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
} else {
vol_idx =
((d_pad * input_height + h_pad) * input_width + w_pad) *
input_channels +
cIm;
}
int col_idx =
((c * output_depth + d) * output_height + h) * output_width +
w;
vol_data[vol_idx] += col_data[col_idx];
}
}
}
}
}
}
};
template class Vol2ColFunctor<platform::CPUDeviceContext, float>;
template class Vol2ColFunctor<platform::CPUDeviceContext, double>;
template class Vol2ColFunctor<phi::CPUContext, float>;
template class Vol2ColFunctor<phi::CPUContext, double>;
template class Col2VolFunctor<platform::CPUDeviceContext, float>;
template class Col2VolFunctor<platform::CPUDeviceContext, double>;
template class Col2VolFunctor<phi::CPUContext, float>;
template class Col2VolFunctor<phi::CPUContext, double>;
} // namespace math
} // namespace operators
......
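// A hedged usage sketch of the Vol2ColFunctor<phi::CPUContext, T>
// specialization added above. ctx, vol and col are assumed to be an
// initialized CPU context, an allocated 4-D input volume ([C, D, H, W] for
// NCHW) and a pre-sized 7-D column tensor, respectively; the helper name is
// illustrative.
void RunVol2ColExample(const phi::CPUContext &ctx,
                       const paddle::framework::Tensor &vol,
                       paddle::framework::Tensor *col) {
  paddle::operators::math::Vol2ColFunctor<phi::CPUContext, float> vol2col;
  vol2col(ctx, vol, /*dilations=*/{1, 1, 1}, /*strides=*/{1, 1, 1},
          /*paddings=*/{0, 0, 0}, col,
          paddle::framework::DataLayout::kNCHW);
}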
......@@ -33,7 +33,7 @@ USE_OP(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP_ITSELF(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP(conv2d);
USE_OP_ITSELF(conv2d);
USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);
namespace paddle {
......@@ -55,7 +55,7 @@ class CacheTester {
onednn_dev_ctx_->ResetBlobMap(nullptr);
}
bool Analyze(unsigned short int num_entries) {
bool Analyze(uint16_t num_entries) {
// Number of created objects in cache should be as expected (num_entries)
return onednn_dev_ctx_->GetCachedObjectsNumber() == num_entries;
}
......
......@@ -57,8 +57,7 @@ static void LaunchCastKernel(const platform::CUDADeviceContext &ctx,
PADDLE_ENFORCE_NE(
static_cast<const void *>(x), static_cast<void *>(y),
platform::errors::InvalidArgument("Inplace cast is not supported yet."));
int vec_size =
std::min(platform::GetVectorizedSize(x), platform::GetVectorizedSize(y));
int vec_size = std::min(phi::GetVectorizedSize(x), phi::GetVectorizedSize(y));
switch (vec_size) {
case 4:
return details::VecCastKernel<InT, OutT, 4>(ctx, x, y, n);
......
......@@ -19,11 +19,11 @@
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
#include "paddle/fluid/operators/optimizers/multi_tensor_apply.h"
#include "paddle/fluid/operators/tensor_to_string.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
......@@ -66,8 +66,8 @@ struct L2NormFunctor {
int i;
for (i = threadIdx.x * VecSize; i + VecSize <= size;
i += (BlockDim * VecSize)) {
platform::AlignedVector<T, VecSize> tmp_vec;
platform::Load(ptr + i, &tmp_vec);
phi::AlignedVector<T, VecSize> tmp_vec;
phi::Load(ptr + i, &tmp_vec);
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
auto tmp = static_cast<MT>(tmp_vec[j]);
......@@ -111,9 +111,9 @@ static int GetChunkedVecSize(const T *ptr, int chunk_size) {
constexpr int max_load_bits = 128;
int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
auto address = reinterpret_cast<uintptr_t>(ptr);
constexpr int vec8 = alignof(platform::AlignedVector<T, 8>);
constexpr int vec4 = alignof(platform::AlignedVector<T, 4>);
constexpr int vec2 = alignof(platform::AlignedVector<T, 2>);
constexpr int vec8 = alignof(phi::AlignedVector<T, 8>);
constexpr int vec4 = alignof(phi::AlignedVector<T, 4>);
constexpr int vec2 = alignof(phi::AlignedVector<T, 2>);
chunk_size *= sizeof(T);
if (address % vec8 == 0 && chunk_size % vec8 == 0) {
return std::min(8, valid_vec_size);
......@@ -316,15 +316,15 @@ static __global__ void ScaleCUDAKernel(const T1 *__restrict__ x,
int stride = blockDim.x * gridDim.x * VecSize;
for (; i + VecSize <= num; i += stride) {
platform::AlignedVector<T1, VecSize> x_vec;
platform::AlignedVector<T1, VecSize> y_vec;
phi::AlignedVector<T1, VecSize> x_vec;
phi::AlignedVector<T1, VecSize> y_vec;
platform::Load(x + i, &x_vec);
phi::Load(x + i, &x_vec);
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
y_vec[j] = static_cast<T1>(static_cast<T2>(x_vec[j]) * s);
}
platform::Store(y_vec, y + i);
phi::Store(y_vec, y + i);
}
for (; i < num; ++i) {
......@@ -410,24 +410,24 @@ static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
int stride = blockDim.x * gridDim.x * VecSize;
for (; i + VecSize <= num; i += stride) {
platform::AlignedVector<T, VecSize> param_vec;
platform::AlignedVector<GradT, VecSize> grad_vec;
platform::AlignedVector<T, VecSize> mom1_vec;
platform::AlignedVector<T, VecSize> mom2_vec;
platform::AlignedVector<T, VecSize> trust_ratio_div_vec;
phi::AlignedVector<T, VecSize> param_vec;
phi::AlignedVector<GradT, VecSize> grad_vec;
phi::AlignedVector<T, VecSize> mom1_vec;
phi::AlignedVector<T, VecSize> mom2_vec;
phi::AlignedVector<T, VecSize> trust_ratio_div_vec;
T cur_weight_decay = (i < weight_decay_end_numel) * weight_decay;
if (cur_weight_decay != static_cast<T>(0.0)) {
platform::Load(param_p + i, &param_vec);
phi::Load(param_p + i, &param_vec);
} else {
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
param_vec[j] = static_cast<T>(0);
}
}
platform::Load(grad_p + i, &grad_vec);
platform::Load(mom1_p + i, &mom1_vec);
platform::Load(mom2_p + i, &mom2_vec);
phi::Load(grad_p + i, &grad_vec);
phi::Load(mom1_p + i, &mom1_vec);
phi::Load(mom2_p + i, &mom2_vec);
#define PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(__param, __grad, __mom1, __mom2, \
__trust_ratio_div, __idx) \
......@@ -450,9 +450,9 @@ static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
mom2_vec, trust_ratio_div_vec, j);
}
platform::Store(mom1_vec, mom1_p + i);
platform::Store(mom2_vec, mom2_p + i);
platform::Store(trust_ratio_div_vec, trust_ratio_div_p + i);
phi::Store(mom1_vec, mom1_p + i);
phi::Store(mom2_vec, mom2_p + i);
phi::Store(trust_ratio_div_vec, trust_ratio_div_p + i);
}
for (; i < num; ++i) {
......@@ -632,29 +632,29 @@ struct LambUpdateParamAndBetaPowsFunctor {
trust_ratio_div += offset;
for (i = threadIdx.x * VecSize; i + VecSize <= size; i += stride) {
platform::AlignedVector<MT, VecSize> trust_ratio_div_vec;
platform::Load(trust_ratio_div + i, &trust_ratio_div_vec);
phi::AlignedVector<MT, VecSize> trust_ratio_div_vec;
phi::Load(trust_ratio_div + i, &trust_ratio_div_vec);
if (HasMasterParam) {
platform::AlignedVector<MT, VecSize> master_param_vec;
platform::Load(master_param + i, &master_param_vec);
platform::AlignedVector<ParamT, VecSize> param_vec;
phi::AlignedVector<MT, VecSize> master_param_vec;
phi::Load(master_param + i, &master_param_vec);
phi::AlignedVector<ParamT, VecSize> param_vec;
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
MT p = master_param_vec[j] - ratio * trust_ratio_div_vec[j];
master_param_vec[j] = p;
param_vec[j] = static_cast<ParamT>(p);
}
platform::Store(master_param_vec, master_param + i);
platform::Store(param_vec, param + i);
phi::Store(master_param_vec, master_param + i);
phi::Store(param_vec, param + i);
} else {
platform::AlignedVector<ParamT, VecSize> param_vec;
platform::Load(param + i, &param_vec);
phi::AlignedVector<ParamT, VecSize> param_vec;
phi::Load(param + i, &param_vec);
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
MT p = static_cast<MT>(param_vec[j]) - ratio * trust_ratio_div_vec[j];
param_vec[j] = static_cast<ParamT>(p);
}
platform::Store(param_vec, param + i);
phi::Store(param_vec, param + i);
}
}
......
......@@ -88,8 +88,8 @@ __device__ inline void VectorizeLarsUpdate(
T* param_out, MT* velocity_out, const MT mu, MT local_lr,
const MT lars_weight_decay, const MT rescale_grad, const int tid,
const int grid_stride, const int numel, MT* master_param_out = nullptr) {
using VecType = paddle::platform::AlignedVector<T, VecSize>;
using VecMType = paddle::platform::AlignedVector<MT, VecSize>;
using VecType = phi::AlignedVector<T, VecSize>;
using VecMType = phi::AlignedVector<MT, VecSize>;
int main = numel >> (VecSize >> 1);
int tail_offset = main * VecSize;
......
......@@ -39,9 +39,9 @@ TEST(test_reduce_rank_check, all) {
}
if (is_valid) {
phi::kernels::details::CheckReduceRank(reduce_rank, rank);
phi::funcs::details::CheckReduceRank(reduce_rank, rank);
} else {
ASSERT_THROW(phi::kernels::details::CheckReduceRank(reduce_rank, rank),
ASSERT_THROW(phi::funcs::details::CheckReduceRank(reduce_rank, rank),
paddle::platform::EnforceNotMet);
}
}
......
......@@ -23,8 +23,7 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
namespace paddle {
namespace operators {
......@@ -37,7 +36,7 @@ void TensorReduceImpl(const platform::CUDADeviceContext& dev_ctx,
gpuStream_t stream) {
y->mutable_data<Ty>(x.place());
phi::kernels::TensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>(
phi::funcs::TensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>(
static_cast<const phi::GPUContext&>(dev_ctx), x, y, transform,
origin_reduce_dims, stream);
}
......
......@@ -16,7 +16,10 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -28,10 +31,6 @@ class SeluOp : public framework::OperatorWithKernel {
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
return UnaryOpUnchangedInferShape(ctx);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -121,7 +120,12 @@ class SeluGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(selu, SeluInferShapeFunctor,
PT_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(selu, ops::SeluOp, ops::SeluOpMaker, ops::SeluOpInferVarType,
ops::SeluGradMaker<paddle::framework::OpDesc>,
ops::SeluGradMaker<paddle::imperative::OpBase>);
ops::SeluGradMaker<paddle::imperative::OpBase>,
SeluInferShapeFunctor);
REGISTER_OPERATOR(selu_grad, ops::SeluGradOp);
......@@ -25,8 +25,9 @@ DECLARE_bool(use_curand);
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/index_impl.cu.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#endif
namespace paddle {
......@@ -206,21 +207,21 @@ void UniformRandom(const framework::ExecutionContext& context,
if (gen_cuda->GetIsInitPy() && seed_flag) {
if (FLAGS_use_curand) {
using MT = typename details::MPTypeTrait<T>::Type;
distribution::uniform_distribution<MT> dist;
distribution::uniform_transform<MT> trans(min, max);
distribution::distribution_and_transform<T>(dev_cxt, tensor, dist, trans);
phi::funcs::uniform_distribution<MT> dist;
phi::funcs::uniform_real_transform<MT> trans(min, max);
phi::funcs::distribution_and_transform<T>(dev_cxt, tensor, dist, trans);
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
auto func =
UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
diag_step, diag_val, gen_offset);
IndexKernel<T, UniformGeneratorOffset<T>>(dev_cxt, tensor, func);
phi::IndexKernel<T, UniformGeneratorOffset<T>>(dev_cxt, tensor, func);
}
} else {
auto func =
UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val);
IndexKernel<T, UniformGenerator<T>>(dev_cxt, tensor, func);
phi::IndexKernel<T, UniformGenerator<T>>(dev_cxt, tensor, func);
}
}
#endif
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include <cstdint>
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#define INT_BITS 32
......@@ -25,7 +25,7 @@ namespace platform {
struct FastDivMod {
// 1st value represents the result of input number divides by recorded divisor
// 2nd value represents the result of input number modulo by recorded divisor
using DivModT = AlignedVector<uint32_t, 2>;
using DivModT = phi::AlignedVector<uint32_t, 2>;
FastDivMod() {}
HOSTDEVICE FastDivMod(uint32_t d) : divisor(d) {
......
......@@ -95,7 +95,7 @@ std::unique_ptr<ProfilerResult> Profiler::Stop() {
collector.ThreadNames();
for (const auto& kv : thread_names) {
extrainfo.AddExtraInfo(string_format(std::string("%llu"), kv.first),
kv.second);
std::string("%s"), kv.second.c_str());
}
return std::unique_ptr<ProfilerResult>(
new platform::ProfilerResult(std::move(tree), extrainfo));
......
......@@ -757,7 +757,7 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
if (obj == Py_None) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"bool, but got %s",
"int, float, bool or Tensor, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
......@@ -784,7 +784,7 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"bool, but got %s",
"int, float, bool or Tensor, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
......@@ -801,7 +801,7 @@ paddle::experimental::ScalarArray CastPyArg2ScalarArray(
if (obj == Py_None) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"bool, but got %s",
"list or Tensor, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
......@@ -821,7 +821,7 @@ paddle::experimental::ScalarArray CastPyArg2ScalarArray(
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"bool, but got %s",
"list or Tensor, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
......@@ -830,5 +830,44 @@ paddle::experimental::ScalarArray CastPyArg2ScalarArray(
return paddle::experimental::ScalarArray({1});
}
paddle::experimental::Backend CastPyArg2Backend(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
if (obj == Py_None) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"int or place, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
PyTypeObject* type = obj->ob_type;
auto type_name = std::string(type->tp_name);
if (type_name == "int") {
int value = CastPyArg2Int(obj, op_type, arg_pos);
return static_cast<paddle::experimental::Backend>(value);
} else {
platform::Place place = CastPyArg2Place(obj, arg_pos);
return phi::TransToPhiBackend(place);
}
return paddle::experimental::Backend::CPU;
}
paddle::experimental::DataType CastPyArg2DataType(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
if (obj == Py_None) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"data_type, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
framework::proto::VarType::Type type = CastPyArg2ProtoType(obj, arg_pos);
return framework::TransToPhiDataType(type);
}
} // namespace pybind
} // namespace paddle
......@@ -11,6 +11,8 @@ limitations under the License. */
#pragma once
#include <Python.h>
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
......@@ -100,6 +102,14 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
paddle::experimental::ScalarArray CastPyArg2ScalarArray(
PyObject* obj, const std::string& op_type, ssize_t arg_pos);
paddle::experimental::Backend CastPyArg2Backend(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
paddle::experimental::DataType CastPyArg2DataType(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
paddle::optional<paddle::experimental::Tensor> GetOptionalTensorFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false);
......
......@@ -17,3 +17,10 @@ def Infrt_KernelOp : Infrt_Op<"kernel", [NoSideEffect]> {
OptionalAttr<DictionaryAttr>:$attrs);
let results = (outs Variadic<AnyType>);
}
def Infrt_CvtTensorOp : Infrt_Op<"cvt_tensor", [NoSideEffect]> {
let summary = "convert tensor type op";
let description = [{convert tensor type op!}];
let arguments = (ins AnyType:$input);
let results = (outs AnyType:$output);
}
......@@ -5,5 +5,8 @@ endif()
add_subdirectory(ir)
add_subdirectory(pass)
add_executable(phi-ir-exec phi_ir_exec.cc)
target_link_libraries(phi-ir-exec infrt)
add_executable(phi-exec phi_exec.cc)
target_link_libraries(phi-exec infrt)
......@@ -3,6 +3,7 @@
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/infrt_base.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
def PHI_Dialect : Dialect {
let name = "phi";
......
......@@ -16,8 +16,10 @@
#include <glog/logging.h>
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/kernel_registry.h"
namespace infrt {
#include "paddle/phi/kernels/declarations.h"
namespace infrt {
namespace {
phi::Backend cvtTarget2Phi(TargetType target) {
switch (target) {
case TargetType::CPU:
......@@ -124,19 +126,76 @@ Place cvtPlaceFromPhi(phi::TensorArgDef tensor_arg) {
cvtLayoutFromPhi(tensor_arg.layout));
}
} // namespace
std::string getPhiTargetPrefix(TargetType target) {
switch (target) {
case TargetType::CPU:
return "phi_cpu.";
case TargetType::GPU:
return "phi_gpu.";
default:
LOG(FATAL) << "UnSupported target type !";
return std::string();
}
}
std::string getPhiPrecisionSuffix(PrecisionType precision) {
switch (precision) {
case PrecisionType::FLOAT32:
return ".float32";
case PrecisionType::FLOAT16:
return ".float16";
case PrecisionType::FLOAT64:
return ".float64";
case PrecisionType::UINT8:
return ".uint8";
case PrecisionType::INT8:
return ".int8";
case PrecisionType::INT16:
return ".int16";
case PrecisionType::INT32:
return ".int32";
case PrecisionType::INT64:
return ".int64";
case PrecisionType::COMPLEX64:
return ".complex64";
case PrecisionType::COMPLEX128:
return ".complex128";
case PrecisionType::BOOL:
return ".bool";
default:
LOG(FATAL) << "UnSupported precision type !";
return std::string();
}
}
std::string getPhiLayoutSuffix(LayoutType layout) {
switch (layout) {
case LayoutType::NCHW:
return ".nchw";
case LayoutType::NHWC:
return ".nhwc";
case LayoutType::ANY:
return ".any";
default:
LOG(FATAL) << "UnSupported layout type !";
return std::string();
}
}
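// Example of how the three helpers above compose the name of a dispatched phi
// kernel (see phiOpCvtPass::diapatchStage later in this commit); "add" is an
// illustrative op name:
//   getPhiTargetPrefix(TargetType::CPU) + "add" +
//       getPhiLayoutSuffix(LayoutType::NCHW) +
//       getPhiPrecisionSuffix(PrecisionType::FLOAT32)
//   yields "phi_cpu.add.nchw.float32".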
std::vector<PhiKernelDesc> getCandidateKernels(
std::string name, const std::vector<Place>& valid_palces) {
std::vector<PhiKernelDesc> candidate_kernels;
PhiKernelDesc phi_kernel_desc;
phi::KernelKeyMap kernel_key_map =
phi::KernelFactory::Instance().SelectKernelMap(name);
for (const Place& place : valid_palces) {
for (Place place : valid_palces) {
phi::KernelKey kernel_key = cvtPlace2Phi(place);
if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) {
kernel_key = phi::KernelKey(kernel_key.backend(),
phi::DataLayout::ALL_LAYOUT,
kernel_key.dtype());
if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) continue;
place.layout = LayoutType::ANY;
}
phi_kernel_desc.kernelType = place;
phi_kernel_desc.inputsType.clear();
......
......@@ -26,6 +26,10 @@ struct PhiKernelDesc {
Place kernelType; // kernel place
};
std::string getPhiTargetPrefix(TargetType target);
std::string getPhiPrecisionSuffix(PrecisionType precision);
std::string getPhiLayoutSuffix(LayoutType layout);
std::vector<PhiKernelDesc> getCandidateKernels(
std::string name, const std::vector<Place>& valid_palces);
......
......@@ -18,11 +18,14 @@
#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/OperationSupport.h>
#include <list>
#include <unordered_set>
#include <vector>
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
......@@ -58,8 +61,8 @@ void phiOpCvtPass::convertStage() {
continue;
}
phi::KernelSignature kernel_sign =
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
::phi::KernelSignature kernel_sign =
::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
ProtoArgumentMappingContext(op));
// resort input&output according to kernel_sign
::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
......@@ -104,13 +107,92 @@ void phiOpCvtPass::diapatchStage() {
infrt::KernelOp kernel_op = ::llvm::dyn_cast_or_null<infrt::KernelOp>(&op);
if (nullptr != kernel_op) worklist.push_back(kernel_op);
}
// ToDo: implementation in the next PR
while (!worklist.empty()) {
// infrt::KernelOp kernel_op = worklist.back();
worklist.pop_back();
// std::string kernel_name = kernel_op.name().str();
// std::vector<PhiKernelDesc> candidates =
// getCandidateKernels(kernel_name, valid_places_);
mlir::OpBuilder builder(&block, block.begin());
std::map<TargetType, mlir::Value> phi_context;
for (infrt::KernelOp kernel_op : worklist) {
std::string kernel_name = kernel_op.name().str();
std::vector<PhiKernelDesc> candidates =
getCandidateKernels(kernel_name, valid_places_);
if (candidates.empty()) {
LOG(FATAL) << "No candidate kernels for op:" << kernel_name;
continue;
}
builder.setInsertionPoint(kernel_op);
    // Todo: implement the concrete kernel pick strategy
const PhiKernelDesc &phi_kernel_desc = candidates.front();
kernel_name = getPhiTargetPrefix(phi_kernel_desc.kernelType.target) +
kernel_name +
getPhiLayoutSuffix(phi_kernel_desc.kernelType.layout) +
getPhiPrecisionSuffix(phi_kernel_desc.kernelType.precision);
// mlir::OperationName operation_name = kernel_op.getOperation()->getName();
mlir::OperationName operation_name(kernel_name, kernel_op.getContext());
mlir::OperationState operation_state(kernel_op.getLoc(), operation_name);
if (phi_context.find(phi_kernel_desc.kernelType.target) ==
phi_context.end()) {
switch (phi_kernel_desc.kernelType.target) {
case TargetType::CPU: {
auto alloctor_value =
builder
.create<infrt::phi::CreateAllocatorOp_cpu>(
kernel_op.getLoc(),
phi::AllocatorType::get(kernel_op.getContext(),
TargetType::CPU))
.output();
auto context_value =
builder
.create<infrt::phi::CreateContextOp_cpu>(
kernel_op.getLoc(),
phi::ContextType::get(kernel_op.getContext(),
TargetType::CPU),
alloctor_value)
.output();
phi_context[TargetType::CPU] = context_value;
} break;
case TargetType::GPU:
case TargetType::UNK:
default:
LOG(FATAL) << "Unsupported TargetType";
break;
}
}
operation_state.addOperands(
phi_context.at(phi_kernel_desc.kernelType.target));
for (size_t index = 0; index < phi_kernel_desc.inputsType.size(); ++index) {
mlir::Value input = kernel_op.getOperand(index);
auto cvt_tensor_type_op = builder.create<CvtTensorOp>(
kernel_op.getLoc(),
DenseTensorType::get(kernel_op.getContext(),
phi_kernel_desc.inputsType[index].target,
phi_kernel_desc.inputsType[index].precision,
phi_kernel_desc.inputsType[index].layout),
input);
operation_state.addOperands(cvt_tensor_type_op.output());
}
for (size_t index = 0; index < phi_kernel_desc.outputsType.size();
++index) {
operation_state.addTypes(
DenseTensorType::get(kernel_op.getContext(),
phi_kernel_desc.outputsType[index].target,
phi_kernel_desc.outputsType[index].precision,
phi_kernel_desc.outputsType[index].layout));
}
operation_state.addAttributes(kernel_op.attrsAttr().getValue());
mlir::Operation *phi_operation = builder.createOperation(operation_state);
for (size_t index = 0; index < phi_kernel_desc.outputsType.size();
++index) {
mlir::Value input = phi_operation->getResult(index);
auto cvt_tensor_type_op = builder.create<CvtTensorOp>(
kernel_op.getLoc(), kernel_op.getResultTypes()[index], input);
kernel_op.getResult(index).replaceAllUsesWith(
cvt_tensor_type_op.output());
}
kernel_op.erase();
}
}
} // namespace infrt
(The remaining file diffs in this commit are collapsed and not shown.)