提交 a6e00159 编写于 作者: P phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into move_histogram_to_pten

......@@ -335,6 +335,17 @@ function(op_library TARGET)
endif()
endforeach()
# pybind USE_OP_DEVICE_KERNEL for ROCm
list (APPEND hip_srcs ${hip_cc_srcs})
# message("hip_srcs ${hip_srcs}")
foreach(hip_src ${hip_srcs})
set(op_name "")
find_register(${hip_src} "REGISTER_OP_CUDA_KERNEL" op_name)
if(NOT ${op_name} EQUAL "")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n")
set(pybind_flag 1)
endif()
endforeach()
# pybind USE_OP_DEVICE_KERNEL for CUDNN/MIOPEN
list(APPEND cudnn_cu_srcs ${cudnn_cu_cc_srcs})
......
......@@ -17,6 +17,12 @@ import re
import argparse
import os
# For API dispatch used at python-level
# { op_name : [arg_name, ...] }
core_ops_returns_info = {}
core_ops_args_info = {}
core_ops_args_type_info = {}
def ParseArguments():
parser = argparse.ArgumentParser(
......@@ -130,17 +136,16 @@ def ParseYamlArgs(string):
attrs_list = []
args = [x.strip() for x in string.strip().split(",")]
atype = r'((const )?\S+) '
aname = r'(\S+)'
aname = r'(.*)'
pattern = f'{atype}{aname}'
for i in range(len(args)):
arg = args[i]
m = re.search(pattern, arg)
arg_type = m.group(1)
arg_name = m.group(3).split("=")[0]
default_value = m.group(3).split("=")[1] if len(m.group(3).split(
"=")) > 1 else None
arg_type = m.group(1).strip()
arg_name = m.group(3).split("=")[0].strip()
default_value = m.group(3).split("=")[1].strip() if len(
m.group(3).split("=")) > 1 else None
if "Tensor" in arg_type:
assert default_value is None
inputs_list.append([arg_name, arg_type, i])
......@@ -262,7 +267,6 @@ def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list,
forward_attr_type = forward_attrs_list[i][1]
forward_attr_default = forward_attrs_list[i][2]
forward_attr_pos = forward_attrs_list[i][3]
assert orig_attr_type == forward_attr_type
assert orig_attr_default == forward_attr_default
assert orig_attr_pos == forward_attr_pos
......@@ -741,26 +745,34 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
# Get Function Args
num_inputs = len(forward_attrs_list) + len(forward_inputs_position_map.keys(
))
inputs_args_list = ["" for i in range(num_inputs)]
inputs_args_definition_list = ["" for i in range(num_inputs)]
inputs_args_declaration_list = ["" for i in range(num_inputs)]
inputs_call_list = ["" for i in range(num_inputs)]
for name, (ttype, pos) in forward_inputs_position_map.items():
inputs_call_list[pos] = f"{name}"
if IsPlainTensorType(ttype):
inputs_args_list[
inputs_args_definition_list[
pos] = f"const paddle::experimental::Tensor& {name}"
inputs_args_declaration_list[
pos] = f"const paddle::experimental::Tensor& {name}"
else:
assert IsVectorTensorType(ttype)
inputs_args_list[
inputs_args_definition_list[
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}"
inputs_args_declaration_list[
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}"
for name, atype, default_val, pos in forward_attrs_list:
inputs_call_list[pos] = name
if default_val is not None:
inputs_args_list[pos] = f"{atype} {name} = {default_val}"
inputs_args_declaration_list[
pos] = f"{atype} {name} = {default_val}"
else:
inputs_args_list[pos] = f"{atype} {name}"
inputs_args_declaration_list[pos] = f"{atype} {name}"
inputs_args_definition_list[pos] = f"{atype} {name}"
inputs_args_str = ", ".join(inputs_args_list)
inputs_args_declaration_str = ", ".join(inputs_args_declaration_list)
inputs_args_definition_str = ", ".join(inputs_args_definition_list)
inputs_call_args_str = ", ".join(inputs_call_list)
# Forward Full Logic
......@@ -812,13 +824,95 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
forward_function_name = GetForwardFunctionName(fwd_api_name)
forward_function_str = FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_str,
returns_type_str, forward_function_name, inputs_args_definition_str,
forward_call_str, node_creation_str, returns_str)
forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_str});"
forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});"
return forward_function_str, forward_function_declaration_str
def CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list):
# fwd_api_name : ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
num_args = len(forward_inputs_position_map.keys()) + len(forward_attrs_list)
num_returns = len(forward_outputs_position_map.keys())
final_state_fwd_api_name = "final_state_" + fwd_api_name
core_ops_returns_info[
final_state_fwd_api_name] = ["" for i in range(num_returns)]
core_ops_args_info[final_state_fwd_api_name] = ["" for i in range(num_args)]
core_ops_args_type_info[
final_state_fwd_api_name] = ["" for i in range(num_args)]
for name, (ttype, pos) in forward_inputs_position_map.items():
core_ops_args_info[final_state_fwd_api_name][pos] = name
if IsPlainTensorType(ttype):
core_ops_args_type_info[final_state_fwd_api_name][pos] = "tensor"
else:
assert IsVectorTensorType(ttype)
core_ops_args_type_info[final_state_fwd_api_name][pos] = "list"
for name, _, _, pos in forward_attrs_list:
core_ops_args_info[final_state_fwd_api_name][pos] = name
for name, (ttype, pos) in forward_outputs_position_map.items():
core_ops_returns_info[final_state_fwd_api_name][pos] = name
def GenerateCoreOpInfoDeclaration():
core_ops_declaration_str = """
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_type_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_returns_info;
"""
return core_ops_declaration_str
def GenerateCoreOpInfoDefinition():
CORE_OPS_INFO_TEMPLATE = """
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info = {{
{}
}};
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_type_info = {{
{}
}};
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_returns_info = {{
{}
}};
"""
op_args_info_list = []
for op_name, arg_list in core_ops_args_info.items():
arg_str = ",".join(["\"" + v + "\"" for v in arg_list])
op_args_info = f"{{ \"{op_name}\", {{ {arg_str} }} }},"
op_args_info_list.append(op_args_info)
op_types_info_list = []
for op_name, type_list in core_ops_args_type_info.items():
type_str = ",".join(["\"" + v + "\"" for v in type_list])
op_types_info = f"{{ \"{op_name}\", {{ {type_str} }} }},"
op_types_info_list.append(op_types_info)
op_returns_info_list = []
for op_name, return_list in core_ops_returns_info.items():
return_str = ",".join(["\"" + v + "\"" for v in return_list])
return_types_info = f"{{ \"{op_name}\", {{ {return_str} }} }},"
op_returns_info_list.append(return_types_info)
op_args_info_str = "\n".join(op_args_info_list)
op_types_info_str = "\n".join(op_types_info_list)
op_returns_info_str = "\n".join(op_returns_info_list)
core_ops_info_definition_str = CORE_OPS_INFO_TEMPLATE.format(
op_args_info_str, op_types_info_str, op_returns_info_str)
return core_ops_info_definition_str
def GenerateNodeCCFile(filepath, node_definition_str):
file_contents = """
#include "glog/logging.h"
......@@ -856,6 +950,8 @@ def GenerateForwardCCFile(filepath, forward_definition_str):
#include "paddle/fluid/eager/api/utils/global_utils.h"
"""
file_contents += GenerateCoreOpInfoDefinition()
file_contents += forward_definition_str
with open(filepath, 'a') as f:
f.write(file_contents)
......@@ -871,6 +967,7 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
#include "paddle/fluid/framework/op_registry.h"
"""
file_contents += GenerateCoreOpInfoDeclaration()
file_contents += forward_function_declaration_str
with open(filepath, 'a') as f:
f.write(file_contents)
......@@ -985,6 +1082,11 @@ if __name__ == "__main__":
forward_definition_str += definition_declaration_pair[0]
forward_declaration_str += definition_declaration_pair[1]
# For python-level API dispatch
CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map,
forward_outputs_position_map,
forward_attrs_list)
# Generate Files
nodes_h_path = args.nodes_h_path
nodes_cc_path = args.nodes_cc_path
......
......@@ -104,6 +104,8 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
PyThreadState *tstate = nullptr;
try
{{
VLOG(6) << "Running Eager Final State API: {}";
// Get EagerTensors from args
{}
......@@ -129,16 +131,87 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
"""
python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
fwd_api_name, get_eager_tensor_str, parse_attributes_str,
fwd_api_name, fwd_api_name, get_eager_tensor_str, parse_attributes_str,
GetForwardFunctionName(fwd_api_name), dygraph_function_call_str)
python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}"
python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}},\n"
return python_c_function_str, python_c_function_reg_str
def GenerateCoreOpsInfoMap():
result = """
static PyObject * eager_get_final_state_core_ops_args_info(PyObject *self) {
PyThreadState *tstate = nullptr;
try
{
return ToPyObject(core_ops_final_state_args_info);
}
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
static PyObject * eager_get_final_state_core_ops_args_type_info(PyObject *self) {
PyThreadState *tstate = nullptr;
try
{
return ToPyObject(core_ops_final_state_args_type_info);
}
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
static PyObject * eager_get_final_state_core_ops_returns_info(PyObject *self) {
PyThreadState *tstate = nullptr;
try
{
return ToPyObject(core_ops_final_state_returns_info);
}
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
"""
core_ops_infos_registry = """
{\"get_final_state_core_ops_args_info\",
(PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_info, METH_NOARGS,
\"C++ interface function for eager_get_final_state_core_ops_args_info.\"},
{\"get_final_state_core_ops_args_type_info\",
(PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_type_info,
METH_NOARGS,
\"C++ interface function for eager_get_final_state_core_ops_args_type_info.\"},
{\"get_final_state_core_ops_returns_info\",
(PyCFunction)(void(*)(void))eager_get_final_state_core_ops_returns_info,
METH_NOARGS, \"C++ interface function for eager_get_final_state_core_ops_returns_info.\"},
"""
return result, core_ops_infos_registry
def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):
core_ops_infos_definition, core_ops_infos_registry = GenerateCoreOpsInfoMap(
)
python_c_function_str += core_ops_infos_definition
python_c_function_reg_str += core_ops_infos_registry
python_c_function_reg_str += "\n {nullptr,nullptr,0,nullptr}"
PYTHON_C_WRAPPER_TEMPLATE = """
#pragma once
......@@ -215,12 +288,12 @@ if __name__ == "__main__":
python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str)
python_c_function_reg_list.append("{nullptr,nullptr,0,nullptr}")
python_c_functions_str = "\n".join(python_c_function_list)
python_c_functions_reg_str = ",\n".join(python_c_function_reg_list)
python_c_str = GeneratePythonCWrappers(python_c_functions_str,
python_c_functions_reg_str)
print("Generated Python-C Codes: ", python_c_str)
output_path = args.output_path
......
......@@ -22,9 +22,12 @@ limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/extension.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_kernel_info_helper.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_factory.h"
......@@ -183,14 +186,14 @@ TEST(CustomKernel, custom_kernel_dot) {
paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
alloc.get(), pten::DenseTensorMeta(pten::DataType::UINT8,
paddle::framework::make_ddim({2, 3}),
pten::framework::make_ddim({2, 3}),
pten::DataLayout::NCHW));
auto* dense_x_data =
dense_x->mutable_data<uint8_t>(paddle::platform::CPUPlace());
auto dense_y = std::make_shared<pten::DenseTensor>(
alloc.get(), pten::DenseTensorMeta(pten::DataType::UINT8,
paddle::framework::make_ddim({2, 3}),
pten::framework::make_ddim({2, 3}),
pten::DataLayout::NCHW));
auto* dense_y_data =
dense_y->mutable_data<uint8_t>(paddle::platform::CPUPlace());
......@@ -231,8 +234,7 @@ TEST(CustomKernel, custom_kernel_dot) {
pten::DataType fake_attr_dtype = pten::DataType::UINT32;
paddle::framework::LoDTensor tmp_tensor;
tmp_tensor.mutable_data<uint8_t>({1}, pten::TransToPtenPlace(backend));
pten::Scalar fake_attr_scalar =
paddle::experimental::MakePtenScalar(tmp_tensor);
pten::Scalar fake_attr_scalar{tmp_tensor};
pten::ScalarArray fake_attr_scalar_array;
std::vector<int64_t> fake_attr_int64_vec;
std::vector<int> fake_attr_int_vec;
......
......@@ -41,6 +41,10 @@ class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext {
return ctx_.HasOutput(name);
}
bool HasAttr(const std::string& name) const override {
return ctx_.HasAttr(name);
}
paddle::any Attr(const std::string& name) const override {
auto& attr = ctx_.Attrs().GetAttr(name);
return GetAttrValue(attr);
......@@ -278,21 +282,47 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
pten::InferMetaContext infer_meta_context(ctx->IsRuntime());
auto& input_names = std::get<0>(signature.args);
auto& attr_names = std::get<1>(signature.args);
auto& output_names = std::get<2>(signature.args);
// TODO(chenweihang): support attrs in next pr
// auto& attr_names = std::get<1>(signature.args);
// TODO(chenweihang): support multiple inputs and outputs
// TODO(chenweihang): support multiple inputs and outputs later
pten::InferMetaContext infer_mete_context;
for (auto& in_name : input_names) {
if (ctx->HasInput(in_name)) {
infer_meta_context.EmplaceBackInput(std::make_shared<CompatMetaTensor>(
ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime()));
} else {
infer_meta_context.EmplaceBackInput({nullptr});
}
}
auto attr_reader = ctx->Attrs();
for (auto& attr_name : attr_names) {
if (ctx->HasAttr(attr_name)) {
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(bool))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(float))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else {
// do nothing, skip useless attrs now
// TODO(chenweihang): support other attr type later and throw error
// if attr is cannot parsed
}
} else {
// do nothing
}
}
for (auto& out_name : output_names) {
if (ctx->HasOutput(out_name)) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
}
}
// TODO(chenweihang): support attrs later
return infer_meta_context;
}
......
......@@ -475,12 +475,11 @@ void InterpreterCore::ExecuteInstructionList(
if (UNLIKELY(exception_holder_.IsCaught())) {
VLOG(1) << "Exception caught " << exception_holder_.Type();
// NOTE(xiongkun) Why we reset ?
// The caught exception may be EOFExcetion, under this situation, we need
// make async_work_queue_ available, so we need reset.
// Graceful exit when the executor encountered a fatal error.
// EOF is not a fatal error.
if (exception_holder_.Type() != "EOF") {
async_work_queue_->Cancel();
async_work_queue_.reset(new interpreter::AsyncWorkQueue(
kHostNumThreads, &main_thread_blocker_));
}
PADDLE_ENFORCE_EQ(
main_thread_blocker_.Clear(), 0,
platform::errors::PreconditionNotMet(
......
......@@ -74,6 +74,10 @@ bool InterpretercoreInferShapeContext::HasOutput(
return out[0] != nullptr;
}
bool InterpretercoreInferShapeContext::HasAttr(const std::string& name) const {
return op_.HasAttr(name);
}
bool InterpretercoreInferShapeContext::HasInputs(
const std::string& name) const {
const auto& ins = ctx_.inputs;
......
......@@ -54,6 +54,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
bool HasOutput(const std::string& name) const override;
bool HasAttr(const std::string& name) const override;
bool HasInputs(const std::string& name) const override;
bool HasOutputs(const std::string& name) const override;
......
......@@ -35,6 +35,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool HasOutput(const std::string &name) const override;
bool HasAttr(const std::string &name) const override;
bool HasInputs(const std::string &name) const override;
bool HasOutputs(const std::string &name) const override;
......@@ -855,6 +857,10 @@ bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
return block_.HasVarRecursive(output_names[0]);
}
bool CompileTimeInferShapeContext::HasAttr(const std::string &name) const {
return op_.HasAttr(name);
}
bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
if (op_.Inputs().find(name) == op_.Inputs().end()) {
return false;
......
......@@ -664,6 +664,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
return out[0] != nullptr;
}
bool HasAttr(const std::string& name) const override {
return op_.HasAttr(name);
}
bool HasInputs(const std::string& name) const override {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
......@@ -2099,6 +2103,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
std::type_index(typeid(std::vector<int32_t>))) {
pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second))));
} else if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(int32_t))) {
pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
&BOOST_GET_CONST(int32_t, attr_iter->second), 1)));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to ScalarArray when "
......
......@@ -455,6 +455,10 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
return ctx_.HasOutput(name);
}
bool HasAttr(const std::string& name) const override {
return ctx_.HasAttr(name);
}
paddle::any Attr(const std::string& name) const override {
auto& attr = ctx_.GetAttr(name);
return GetAttrValue(attr);
......
......@@ -61,6 +61,7 @@ class InferShapeContext {
virtual ~InferShapeContext() = default;
virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0;
virtual bool HasAttr(const std::string &name) const = 0;
virtual std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const = 0;
......
......@@ -78,6 +78,10 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
return out[0] != nullptr;
}
bool HasAttr(const std::string& name) const override {
return attrs_->count(name) > 0 || default_attrs_->count(name) > 0;
}
bool HasInputs(const std::string& name) const override {
auto it = var_map_in_->find(name);
if (it == var_map_in_->end() || it->second.empty()) {
......
......@@ -346,6 +346,14 @@ void BuildDygraphPtenKernelContext(
std::type_index(typeid(std::vector<int32_t>))) {
kernel_ctx->EmplaceBackAttr(std::move(
pten::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int64_t))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::ScalarArray(&BOOST_GET_CONST(int64_t, attr), 1)));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int32_t))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::ScalarArray(&BOOST_GET_CONST(int32_t, attr), 1)));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
......
......@@ -217,7 +217,7 @@ TEST(test_prepare_op, test_prepare_data_cpu_mkldnn) {
} // namespace imperative
} // namespace paddle
USE_OP(split);
USE_OP_ITSELF(split);
USE_OP(relu);
#ifdef PADDLE_WITH_MKLDNN
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
#if defined(PADDLE_WITH_CNCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/mlu/cncl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class CBroadcastOPMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_CNCL)
auto x = ctx.Input<framework::LoDTensor>("X");
auto out = ctx.Output<framework::LoDTensor>("Out");
int numel = x->numel();
cnclDataType_t dtype = platform::ToCNCLDataType(x->type());
int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto comm = platform::CNCLCommContext::Instance().Get(rid, place);
mluStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::MLUDeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
int root = ctx.Attr<int>("root");
if (root == comm->rank()) {
PADDLE_ENFORCE_MLU_SUCCESS(
cnclBcast(reinterpret_cast<void*>(const_cast<T*>(x->data<T>())),
numel, dtype, root, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent "
<< x->numel();
if (out != x) {
framework::TensorCopy(
*static_cast<const framework::Tensor*>(x), place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<framework::Tensor*>(out));
}
} else {
PADDLE_ENFORCE_MLU_SUCCESS(cnclBcast(out->mutable_data<T>(place), numel,
dtype, root, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. recieved "
<< framework::product(out->dims());
}
out->Resize(x->dims());
out->set_lod(x->lod());
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with MLU."));
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(c_broadcast, ops::CBroadcastOPMLUKernel<float>,
ops::CBroadcastOPMLUKernel<plat::float16>,
ops::CBroadcastOPMLUKernel<int>,
ops::CBroadcastOPMLUKernel<int16_t>,
ops::CBroadcastOPMLUKernel<int8_t>,
ops::CBroadcastOPMLUKernel<uint8_t>);
......@@ -47,8 +47,12 @@ class GatherNdXPUKernel : public framework::OpKernel<T> {
auto x_shape = paddle::framework::vectorize<int>(x->dims());
auto index_shape = paddle::framework::vectorize<int>(index->dims());
if (index_shape.size() == 1) {
index_shape.insert(index_shape.begin(), 1);
}
xpu::VectorParam<int> x_vec = {x_shape.data(),
static_cast<int>(x_shape.size()), nullptr};
auto &dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
int ret = XPU_SUCCESS;
......
......@@ -16,6 +16,10 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/infermeta/backward.h"
namespace paddle {
namespace operators {
......@@ -343,25 +347,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul_v2");
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul_v2");
OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "matmul_v2");
auto x_dims = context->GetInputDim("X");
auto y_dims = context->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (context->HasOutput(x_grad_name)) {
context->SetOutputDim(x_grad_name, x_dims);
}
if (context->HasOutput(y_grad_name)) {
context->SetOutputDim(y_grad_name, y_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
......@@ -539,9 +524,12 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker,
ops::MatMulV2GradOpMaker<paddle::framework::OpDesc>,
ops::MatMulV2GradOpMaker<paddle::imperative::OpBase>);
DELCARE_INFER_SHAPE_FUNCTOR(matmul_v2_grad, MatMulV2GradInferShapeFunctor,
PT_INFER_META(pten::MatmulGradInferMeta));
REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad,
ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>);
ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>,
MatMulV2GradInferShapeFunctor);
REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad,
ops::MatMulV2OpTripleGradMaker<paddle::framework::OpDesc>,
......
......@@ -172,11 +172,3 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker,
ops::SplitGradMaker<paddle::framework::OpDesc>,
ops::SplitGradMaker<paddle::imperative::OpBase>);
namespace plat = paddle::platform;
REGISTER_OP_CPU_KERNEL(
split, ops::SplitOpKernel<plat::CPUDeviceContext, double>,
ops::SplitOpKernel<plat::CPUDeviceContext, float>,
ops::SplitOpKernel<plat::CPUDeviceContext, int64_t>,
ops::SplitOpKernel<plat::CPUDeviceContext, int>,
ops::SplitOpKernel<plat::CPUDeviceContext, bool>,
ops::SplitOpKernel<plat::CPUDeviceContext, plat::float16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/split_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
split, ops::SplitOpKernel<plat::CUDADeviceContext, double>,
ops::SplitOpKernel<plat::CUDADeviceContext, float>,
ops::SplitOpKernel<plat::CUDADeviceContext, int64_t>,
ops::SplitOpKernel<plat::CUDADeviceContext, int>,
ops::SplitOpKernel<plat::CUDADeviceContext, bool>,
ops::SplitOpKernel<plat::CUDADeviceContext, plat::float16>,
ops::SplitOpKernel<plat::CUDADeviceContext, plat::bfloat16>);
......@@ -19,10 +19,8 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/pten/kernels/split_kernel.h"
namespace paddle {
namespace operators {
static inline std::vector<framework::DDim> UpdateOutsDims(
......@@ -108,56 +106,6 @@ static inline std::vector<framework::DDim> UpdateOutsDims(
}
return outs_dims;
}
template <typename DeviceContext, typename T>
class SplitOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
int num = ctx.Attr<int>("num");
std::vector<int> sections = ctx.Attr<std::vector<int>>("sections");
int axis = ctx.Attr<int>("axis");
auto in_dims = in->dims();
auto outs_number = outs.size();
bool need_resize_outs_dims = false;
if (ctx.HasInput("AxisTensor")) {
auto* axis_tensor = ctx.Input<framework::Tensor>("AxisTensor");
axis = GetDataFromTensor(axis_tensor)[0];
need_resize_outs_dims = true;
}
auto sections_tensor_list =
ctx.MultiInput<framework::Tensor>("SectionsTensorList");
if (sections_tensor_list.size() > 0) {
sections = GetDataFromTensorList(sections_tensor_list);
need_resize_outs_dims = true;
}
if (need_resize_outs_dims) {
std::vector<framework::DDim> outs_dims =
UpdateOutsDims(true, true, in_dims, num, sections, axis, outs_number);
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->Resize(outs_dims[j]);
}
}
std::vector<const framework::Tensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->mutable_data<T>(ctx.GetPlace());
shape_refer.emplace_back(outs[j]);
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && outs.size() < 10) {
StridedMemcpyWithAxis0<T>(dev_ctx, *in, shape_refer, &outs);
} else {
math::SplitFunctor<DeviceContext, T> functor;
functor(dev_ctx, *in, shape_refer, axis, &outs);
}
}
};
template <typename T>
class SplitGradMaker : public framework::SingleGradOpMaker<T> {
......
cc_library(host_tracer SRCS host_tracer.cc DEPS enforce)
cc_library(new_profiler SRCS profiler.cc DEPS host_tracer)
cc_library(cuda_tracer SRCS cuda_tracer.cc cupti_data_process.cc DEPS workqueue_utils enforce glog)
cc_library(new_profiler SRCS profiler.cc DEPS host_tracer cuda_tracer)
cc_library(event_node SRCS event_node.cc DEPS enforce)
cc_library(chrometracinglogger SRCS chrometracing_logger.cc DEPS event_node)
cc_test(test_event_node SRCS test_event_node.cc DEPS event_node chrometracinglogger)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/profiler/cuda_tracer.h"
#include <string>
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h"
#include "paddle/fluid/platform/os_info.h"
#include "paddle/fluid/platform/profiler/cupti_data_process.h"
#define CUPTI_CALL(call) \
do { \
CUptiResult _status = call; \
if (_status != CUPTI_SUCCESS) { \
const char* errstr; \
dynload::cuptiGetResultString(_status, &errstr); \
LOG(ERROR) << "Function " << #call << " failed with error " << errstr; \
exit(-1); \
} \
} while (0)
namespace paddle {
namespace platform {
namespace details {
std::unordered_map<uint32_t, uint64_t> CreateThreadIdMapping() {
std::unordered_map<uint32_t, uint64_t> mapping;
std::unordered_map<uint64_t, ThreadId> ids = GetAllThreadIds();
for (const auto& id : ids) {
mapping[id.second.cupti_tid] = id.second.sys_tid;
}
return mapping;
}
} // namespace details
CudaTracer::CudaTracer() {}
void CudaTracer::PrepareTracing() {
PADDLE_ENFORCE_EQ(
state_ == TracerState::UNINITED || state_ == TracerState::STOPED, true,
platform::errors::PreconditionNotMet("Tracer must be UNINITED"));
EnableCuptiActivity();
state_ = TracerState::READY;
}
void CudaTracer::StartTracing() {
PADDLE_ENFORCE_EQ(
state_ == TracerState::READY, true,
platform::errors::PreconditionNotMet("Tracer must be READY or STOPPED"));
ConsumeBuffers();
tracing_start_ns_ = PosixInNsec();
state_ = TracerState::STARTED;
}
void CudaTracer::StopTracing() {
PADDLE_ENFORCE_EQ(
state_, TracerState::STARTED,
platform::errors::PreconditionNotMet("Tracer must be STARTED"));
DisableCuptiActivity();
state_ = TracerState::STOPED;
}
void CudaTracer::CollectTraceData(TraceEventCollector* collector) {
PADDLE_ENFORCE_EQ(
state_, TracerState::STOPED,
platform::errors::PreconditionNotMet("Tracer must be STOPED"));
ProcessCuptiActivity(collector);
}
int CudaTracer::ProcessCuptiActivity(TraceEventCollector* collector) {
int record_cnt = 0;
#ifdef PADDLE_WITH_CUPTI
CUPTI_CALL(dynload::cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED));
auto mapping = details::CreateThreadIdMapping();
std::vector<ActivityBuffer> buffers = ConsumeBuffers();
for (auto& buffer : buffers) {
if (buffer.addr == nullptr || buffer.valid_size == 0) {
continue;
}
CUpti_Activity* record = nullptr;
while (true) {
CUptiResult status = dynload::cuptiActivityGetNextRecord(
buffer.addr, buffer.valid_size, &record);
if (status == CUPTI_SUCCESS) {
details::ProcessCuptiActivityRecord(record, tracing_start_ns_, mapping,
collector);
++record_cnt;
} else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
break;
} else {
CUPTI_CALL(status);
}
}
ReleaseBuffer(buffer.addr);
}
#endif
return record_cnt;
}
void CudaTracer::EnableCuptiActivity() {
#ifdef PADDLE_WITH_CUPTI
CUPTI_CALL(dynload::cuptiActivityRegisterCallbacks(BufferRequestedCallback,
BufferCompletedCallback));
CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY));
CUPTI_CALL(
dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL));
CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER));
CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME));
CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMSET));
VLOG(3) << "enable cupti activity";
#endif
}
void CudaTracer::DisableCuptiActivity() {
#ifdef PADDLE_WITH_CUPTI
CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMCPY));
CUPTI_CALL(
dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL));
CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DRIVER));
CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_RUNTIME));
CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMSET));
VLOG(3) << "disable cupti activity";
#endif
}
#ifdef PADDLE_WITH_CUPTI
void CUPTIAPI CudaTracer::BufferRequestedCallback(uint8_t** buffer,
size_t* size,
size_t* max_num_records) {
GetInstance().AllocateBuffer(buffer, size);
*max_num_records = 0;
}
void CUPTIAPI CudaTracer::BufferCompletedCallback(CUcontext ctx,
uint32_t stream_id,
uint8_t* buffer, size_t size,
size_t valid_size) {
GetInstance().ProduceBuffer(buffer, valid_size);
size_t dropped = 0;
CUPTI_CALL(
dynload::cuptiActivityGetNumDroppedRecords(ctx, stream_id, &dropped));
if (dropped != 0) {
LOG(WARNING) << "Stream " << stream_id << " Dropped " << dropped
<< " activity records";
}
}
#endif
void CudaTracer::AllocateBuffer(uint8_t** buffer, size_t* size) {
constexpr size_t kBufSize = 1 << 23; // 8 MB
constexpr size_t kBufAlign = 8; // 8 B
*buffer = reinterpret_cast<uint8_t*>(
paddle::framework::AlignedMalloc(kBufSize, kBufAlign));
*size = kBufSize;
}
void CudaTracer::ProduceBuffer(uint8_t* buffer, size_t valid_size) {
std::lock_guard<std::mutex> guard(activity_buffer_lock_);
activity_buffers_.emplace_back(buffer, valid_size);
}
std::vector<CudaTracer::ActivityBuffer> CudaTracer::ConsumeBuffers() {
std::vector<ActivityBuffer> buffers;
{
std::lock_guard<std::mutex> guard(activity_buffer_lock_);
buffers.swap(activity_buffers_);
}
return buffers;
}
void CudaTracer::ReleaseBuffer(uint8_t* buffer) {
paddle::framework::AlignedFree(buffer);
}
} // namespace platform
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <mutex>
#include <vector>
#include "paddle/fluid/platform/dynload/cupti.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/profiler/tracer_base.h"
namespace paddle {
namespace platform {
// Based on CUDA CUPTI
class CudaTracer : public TracerBase {
public:
// Singleton. CUPTI imposes this restriction.
static CudaTracer& GetInstance() {
static CudaTracer instance;
return instance;
}
void PrepareTracing() override;
void StartTracing() override;
void StopTracing() override;
void CollectTraceData(TraceEventCollector* collector) override;
private:
struct ActivityBuffer {
ActivityBuffer(uint8_t* addr, size_t size) : addr(addr), valid_size(size) {}
uint8_t* addr;
size_t valid_size;
};
CudaTracer();
DISABLE_COPY_AND_ASSIGN(CudaTracer);
void EnableCuptiActivity();
void DisableCuptiActivity();
int ProcessCuptiActivity(TraceEventCollector* collector);
#ifdef PADDLE_WITH_CUPTI
// Used by CUPTI Activity API to request buffer
static void CUPTIAPI BufferRequestedCallback(uint8_t** buffer, size_t* size,
size_t* max_num_records);
// Used by CUPTI Activity API to commit a completed buffer
static void CUPTIAPI BufferCompletedCallback(CUcontext ctx,
uint32_t stream_id,
uint8_t* buffer, size_t size,
size_t valid_size);
#endif
void AllocateBuffer(uint8_t** buffer, size_t* size);
void ProduceBuffer(uint8_t* buffer, size_t valid_size);
std::vector<ActivityBuffer> ConsumeBuffers();
void ReleaseBuffer(uint8_t* buffer);
uint64_t tracing_start_ns_ = UINT64_MAX;
std::mutex activity_buffer_lock_;
std::vector<ActivityBuffer> activity_buffers_;
};
} // namespace platform
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/profiler/cupti_data_process.h"
#include <cstdio>
#include "paddle/fluid/platform/os_info.h"
namespace paddle {
namespace platform {
namespace details {
#ifdef PADDLE_WITH_CUPTI
void AddKernelRecord(const CUpti_ActivityKernel4* kernel, uint64_t start_ns,
TraceEventCollector* collector) {
if (kernel->start < start_ns) {
return;
}
DeviceTraceEvent event;
event.name = kernel->name;
event.type = TracerEventType::Kernel;
event.start_ns = kernel->start;
event.end_ns = kernel->end;
event.device_id = kernel->deviceId;
event.context_id = kernel->contextId;
event.stream_id = kernel->streamId;
event.correlation_id = kernel->correlationId;
event.kernel_info.block_x = kernel->blockX;
event.kernel_info.block_y = kernel->blockY;
event.kernel_info.block_z = kernel->blockZ;
event.kernel_info.grid_x = kernel->gridX;
event.kernel_info.grid_y = kernel->gridY;
event.kernel_info.grid_z = kernel->gridZ;
event.kernel_info.dynamic_shared_memory = kernel->dynamicSharedMemory;
event.kernel_info.static_shared_memory = kernel->staticSharedMemory;
event.kernel_info.registers_per_thread = kernel->registersPerThread;
event.kernel_info.local_memory_per_thread = kernel->localMemoryPerThread;
event.kernel_info.local_memory_total = kernel->localMemoryTotal;
event.kernel_info.queued = kernel->queued;
event.kernel_info.submitted = kernel->submitted;
event.kernel_info.completed = kernel->completed;
collector->AddDeviceEvent(std::move(event));
}
const char* MemcpyKind(uint8_t kind) {
switch (kind) {
case CUPTI_ACTIVITY_MEMCPY_KIND_HTOD:
return "MEMCPY_HtoD";
case CUPTI_ACTIVITY_MEMCPY_KIND_DTOH:
return "MEMCPY_DtoH";
case CUPTI_ACTIVITY_MEMCPY_KIND_HTOA:
return "MEMCPY_HtoA";
case CUPTI_ACTIVITY_MEMCPY_KIND_ATOH:
return "MEMCPY_AtoH";
case CUPTI_ACTIVITY_MEMCPY_KIND_ATOA:
return "MEMCPY_AtoA";
case CUPTI_ACTIVITY_MEMCPY_KIND_ATOD:
return "MEMCPY_AtoD";
case CUPTI_ACTIVITY_MEMCPY_KIND_DTOA:
return "MEMCPY_DtoA";
case CUPTI_ACTIVITY_MEMCPY_KIND_DTOD:
return "MEMCPY_DtoD";
case CUPTI_ACTIVITY_MEMCPY_KIND_HTOH:
return "MEMCPY_HtoH";
case CUPTI_ACTIVITY_MEMCPY_KIND_PTOP:
return "MEMCPY_PtoP";
default:
return "MEMCPY";
}
}
const char* MemoryKind(uint16_t kind) {
switch (kind) {
case CUPTI_ACTIVITY_MEMORY_KIND_UNKNOWN:
return "Unknown";
case CUPTI_ACTIVITY_MEMORY_KIND_PAGEABLE:
return "Pageable";
case CUPTI_ACTIVITY_MEMORY_KIND_PINNED:
return "Pinned";
case CUPTI_ACTIVITY_MEMORY_KIND_DEVICE:
return "Device";
case CUPTI_ACTIVITY_MEMORY_KIND_ARRAY:
return "Array";
case CUPTI_ACTIVITY_MEMORY_KIND_MANAGED:
return "Managed";
case CUPTI_ACTIVITY_MEMORY_KIND_DEVICE_STATIC:
return "Device Static";
case CUPTI_ACTIVITY_MEMORY_KIND_MANAGED_STATIC:
return "Managed Static";
default:
return "Unknown";
}
}
void AddMemcpyRecord(const CUpti_ActivityMemcpy* memcpy, uint64_t start_ns,
TraceEventCollector* collector) {
if (memcpy->start < start_ns) {
return;
}
DeviceTraceEvent event;
event.name = MemcpyKind(memcpy->copyKind);
event.type = TracerEventType::Memcpy;
event.start_ns = memcpy->start;
event.end_ns = memcpy->end;
event.device_id = memcpy->deviceId;
event.context_id = memcpy->contextId;
event.stream_id = memcpy->streamId;
event.correlation_id = memcpy->correlationId;
event.memcpy_info.num_bytes = memcpy->bytes;
// snprintf(event.memcpy_info.copy_kind, kMemKindMaxLen, "%s",
// MemcpyKind(memcpy->copyKind));
snprintf(event.memcpy_info.src_kind, kMemKindMaxLen, "%s",
MemcpyKind(memcpy->srcKind));
snprintf(event.memcpy_info.dst_kind, kMemKindMaxLen, "%s",
MemcpyKind(memcpy->dstKind));
collector->AddDeviceEvent(std::move(event));
}
void AddMemcpy2Record(const CUpti_ActivityMemcpy2* memcpy2, uint64_t start_ns,
TraceEventCollector* collector) {
if (memcpy2->start < start_ns) {
return;
}
DeviceTraceEvent event;
event.name = MemcpyKind(memcpy2->copyKind);
event.type = TracerEventType::Memcpy;
event.start_ns = memcpy2->start;
event.end_ns = memcpy2->end;
event.device_id = memcpy2->deviceId;
event.context_id = memcpy2->contextId;
event.stream_id = memcpy2->streamId;
event.correlation_id = memcpy2->correlationId;
event.memcpy_info.num_bytes = memcpy2->bytes;
// snprintf(event.memcpy_info.copy_kind, kMemKindMaxLen, "%s",
// MemcpyKind(memcpy2->copyKind));
snprintf(event.memcpy_info.src_kind, kMemKindMaxLen, "%s",
MemcpyKind(memcpy2->srcKind));
snprintf(event.memcpy_info.dst_kind, kMemKindMaxLen, "%s",
MemcpyKind(memcpy2->dstKind));
collector->AddDeviceEvent(std::move(event));
}
void AddMemsetRecord(const CUpti_ActivityMemset* memset, uint64_t start_ns,
TraceEventCollector* collector) {
if (memset->start < start_ns) {
return;
}
DeviceTraceEvent event;
event.name = "MEMSET";
event.type = TracerEventType::Memset;
event.start_ns = memset->start;
event.end_ns = memset->end;
event.device_id = memset->deviceId;
event.context_id = memset->contextId;
event.stream_id = memset->streamId;
event.correlation_id = memset->correlationId;
event.memset_info.num_bytes = memset->bytes;
snprintf(event.memset_info.memory_kind, kMemKindMaxLen, "%s",
MemoryKind(memset->memoryKind));
event.memset_info.value = memset->value;
collector->AddDeviceEvent(std::move(event));
}
class CuptiRuntimeCbidStr {
public:
static const CuptiRuntimeCbidStr& GetInstance() {
static CuptiRuntimeCbidStr inst;
return inst;
}
std::string RuntimeKind(CUpti_CallbackId cbid) const {
auto iter = cbid_str_.find(cbid);
if (iter == cbid_str_.end()) {
return "Runtime API " + std::to_string(cbid);
}
return iter->second;
}
private:
CuptiRuntimeCbidStr();
std::unordered_map<CUpti_CallbackId, std::string> cbid_str_;
};
CuptiRuntimeCbidStr::CuptiRuntimeCbidStr() {
#define REGISTER_RUNTIME_CBID_STR(cbid) \
cbid_str_[CUPTI_RUNTIME_TRACE_CBID_##cbid] = #cbid
REGISTER_RUNTIME_CBID_STR(cudaBindTexture_v3020);
REGISTER_RUNTIME_CBID_STR(cudaConfigureCall_v3020);
REGISTER_RUNTIME_CBID_STR(cudaDeviceGetAttribute_v5000);
REGISTER_RUNTIME_CBID_STR(cudaDeviceGetStreamPriorityRange_v5050);
REGISTER_RUNTIME_CBID_STR(cudaDeviceSynchronize_v3020);
REGISTER_RUNTIME_CBID_STR(cudaDriverGetVersion_v3020);
REGISTER_RUNTIME_CBID_STR(cudaEventCreateWithFlags_v3020);
REGISTER_RUNTIME_CBID_STR(cudaEventDestroy_v3020);
REGISTER_RUNTIME_CBID_STR(cudaEventDestroy_v3020);
REGISTER_RUNTIME_CBID_STR(cudaEventQuery_v3020);
REGISTER_RUNTIME_CBID_STR(cudaEventRecord_v3020);
REGISTER_RUNTIME_CBID_STR(cudaFreeHost_v3020);
REGISTER_RUNTIME_CBID_STR(cudaFree_v3020);
REGISTER_RUNTIME_CBID_STR(cudaFuncGetAttributes_v3020);
REGISTER_RUNTIME_CBID_STR(cudaGetDeviceCount_v3020);
REGISTER_RUNTIME_CBID_STR(cudaGetDeviceProperties_v3020);
REGISTER_RUNTIME_CBID_STR(cudaGetDevice_v3020);
REGISTER_RUNTIME_CBID_STR(cudaGetErrorString_v3020);
REGISTER_RUNTIME_CBID_STR(cudaGetLastError_v3020);
REGISTER_RUNTIME_CBID_STR(cudaHostAlloc_v3020);
REGISTER_RUNTIME_CBID_STR(cudaHostGetDevicePointer_v3020);
REGISTER_RUNTIME_CBID_STR(cudaLaunchKernel_v7000);
REGISTER_RUNTIME_CBID_STR(cudaMallocHost_v3020);
REGISTER_RUNTIME_CBID_STR(cudaMalloc_v3020);
REGISTER_RUNTIME_CBID_STR(cudaMemcpyAsync_v3020);
REGISTER_RUNTIME_CBID_STR(cudaMemcpy_v3020);
REGISTER_RUNTIME_CBID_STR(cudaMemsetAsync_v3020);
REGISTER_RUNTIME_CBID_STR(cudaMemset_v3020);
REGISTER_RUNTIME_CBID_STR(
cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags_v7000);
REGISTER_RUNTIME_CBID_STR(cudaPeekAtLastError_v3020);
REGISTER_RUNTIME_CBID_STR(cudaRuntimeGetVersion_v3020);
REGISTER_RUNTIME_CBID_STR(cudaSetDevice_v3020);
REGISTER_RUNTIME_CBID_STR(cudaStreamCreate_v3020);
REGISTER_RUNTIME_CBID_STR(cudaStreamCreateWithFlags_v5000);
REGISTER_RUNTIME_CBID_STR(cudaStreamCreateWithPriority_v5050);
REGISTER_RUNTIME_CBID_STR(cudaStreamDestroy_v5050);
REGISTER_RUNTIME_CBID_STR(cudaStreamSynchronize_v3020);
REGISTER_RUNTIME_CBID_STR(cudaStreamWaitEvent_v3020);
REGISTER_RUNTIME_CBID_STR(cudaUnbindTexture_v3020);
REGISTER_RUNTIME_CBID_STR(cudaSetupArgument_v3020);
REGISTER_RUNTIME_CBID_STR(cudaLaunch_v3020);
REGISTER_RUNTIME_CBID_STR(cudaDeviceGetPCIBusId_v4010);
#if CUDA_VERSION >= 9000
REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernel_v9000);
REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernelMultiDevice_v9000);
#endif
#undef REGISTER_RUNTIME_CBID_STR
}
void AddApiRecord(const CUpti_ActivityAPI* api, uint64_t start_ns,
const std::unordered_map<uint32_t, uint64_t> tid_mapping,
TraceEventCollector* collector) {
if (api->start < start_ns) {
return;
}
RuntimeTraceEvent event;
event.name = CuptiRuntimeCbidStr::GetInstance().RuntimeKind(api->cbid);
event.start_ns = api->start;
event.end_ns = api->end;
event.process_id = GetProcessId();
uint64_t tid = 0;
auto iter = tid_mapping.find(api->threadId);
if (iter == tid_mapping.end()) {
} else {
tid = iter->second;
}
event.thread_id = tid;
event.correlation_id = api->correlationId;
event.callback_id = api->cbid;
collector->AddRuntimeEvent(std::move(event));
}
void ProcessCuptiActivityRecord(
const CUpti_Activity* record, uint64_t start_ns,
const std::unordered_map<uint32_t, uint64_t> tid_mapping,
TraceEventCollector* collector) {
switch (record->kind) {
case CUPTI_ACTIVITY_KIND_KERNEL:
case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL:
AddKernelRecord(reinterpret_cast<const CUpti_ActivityKernel4*>(record),
start_ns, collector);
break;
case CUPTI_ACTIVITY_KIND_MEMCPY:
AddMemcpyRecord(reinterpret_cast<const CUpti_ActivityMemcpy*>(record),
start_ns, collector);
break;
case CUPTI_ACTIVITY_KIND_MEMCPY2:
AddMemcpy2Record(reinterpret_cast<const CUpti_ActivityMemcpy2*>(record),
start_ns, collector);
break;
case CUPTI_ACTIVITY_KIND_MEMSET:
AddMemsetRecord(reinterpret_cast<const CUpti_ActivityMemset*>(record),
start_ns, collector);
break;
case CUPTI_ACTIVITY_KIND_DRIVER:
case CUPTI_ACTIVITY_KIND_RUNTIME:
AddApiRecord(reinterpret_cast<const CUpti_ActivityAPI*>(record), start_ns,
tid_mapping, collector);
break;
default:
break;
}
}
#endif
} // namespace details
} // namespace platform
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unordered_map>
#include "paddle/fluid/platform/dynload/cupti.h"
#include "paddle/fluid/platform/profiler/trace_event_collector.h"
namespace paddle {
namespace platform {
namespace details {
#ifdef PADDLE_WITH_CUPTI
void ProcessCuptiActivityRecord(
const CUpti_Activity* record, uint64_t start_ns,
const std::unordered_map<uint32_t, uint64_t> tid_mapping,
TraceEventCollector* collector);
#endif
} // namespace details
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/profiler/host_tracer.h"
#include "glog/logging.h"
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/profiler/profiler.h"
#include "glog/logging.h"
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#endif
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler/cuda_tracer.h"
#include "paddle/fluid/platform/profiler/host_tracer.h"
#include "paddle/fluid/platform/profiler/trace_event_collector.h"
......@@ -46,6 +47,7 @@ Profiler::Profiler(const ProfilerOptions& options) {
HostTracerOptions host_tracer_options;
host_tracer_options.trace_level = options.trace_level;
tracers_.emplace_back(new HostTracer(host_tracer_options), true);
tracers_.emplace_back(&CudaTracer::GetInstance(), false);
}
Profiler::~Profiler() { alive_.store(false); }
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <set>
#include <string>
......@@ -44,10 +44,44 @@ TEST(ProfilerTest, TestHostTracer) {
}
auto nodetree = profiler->Stop();
std::set<std::string> host_events;
for (const auto pair : nodetree->Traverse(true))
for (const auto pair : nodetree->Traverse(true)) {
for (const auto evt : pair.second) {
host_events.insert(evt->Name());
}
}
EXPECT_EQ(host_events.count("TestTraceLevel_record1"), 1u);
EXPECT_EQ(host_events.count("TestTraceLevel_record2"), 0u);
}
TEST(ProfilerTest, TestCudaTracer) {
using paddle::platform::ProfilerOptions;
using paddle::platform::Profiler;
ProfilerOptions options;
options.trace_level = 0;
auto profiler = Profiler::Create(options);
EXPECT_TRUE(profiler);
profiler->Prepare();
profiler->Start();
#ifdef PADDLE_WITH_CUDA
cudaStream_t stream;
cudaStreamCreate(&stream);
cudaStreamSynchronize(stream);
#endif
#ifdef PADDLE_WITH_HIP
hipStream_t stream;
hipStreamCreate(&stream);
hipStreamSynchronize(stream);
#endif
auto nodetree = profiler->Stop();
std::vector<std::string> runtime_events;
for (const auto pair : nodetree->Traverse(true)) {
for (const auto host_node : pair.second) {
for (auto runtime_node : host_node->GetRuntimeTraceEventNodes()) {
runtime_events.push_back(runtime_node->Name());
}
}
}
#ifdef PADDLE_WITH_CUPTI
EXPECT_GT(runtime_events.size(), 0u);
#endif
}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
......
......@@ -506,7 +506,7 @@ PyObject* ToPyObject(const paddle::framework::proto::VarType& type) {
}
PyObject* ToPyObject(const paddle::framework::LoDTensor* value) {
auto obj = ::pybind11::cast(value, py::return_value_policy::copy);
auto obj = ::pybind11::cast(value, py::return_value_policy::reference);
obj.inc_ref();
return obj.ptr();
}
......
......@@ -16,6 +16,8 @@ limitations under the License. */
#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/common/backend.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
/**
* This file stores some special APIs that are implemented manually
......@@ -28,5 +30,11 @@ namespace experimental {
// TODO(chenweihang): Replace backend by place when place is ready
PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking);
// TODO(chentianyu03): Split API has extra logic to calculate the outputs size,
// api_gen do not support
PADDLE_API std::vector<Tensor> split(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis);
} // namespace experimental
} // namespace paddle
......@@ -19,9 +19,12 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/pten/infermeta/unary.h"
PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
......@@ -75,6 +78,71 @@ PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking) {
return out;
}
PADDLE_API std::vector<Tensor> split(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"split", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "split API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "split API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto dense_x = PrepareData(x, kernel.InputAt(0), {});
// Calculate the number of out tensors
size_t out_number;
if (num_or_sections.GetData().size() == 1) {
out_number = num_or_sections.GetData()[0];
} else {
out_number = num_or_sections.GetData().size();
}
std::vector<Tensor> out;
auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out);
std::vector<pten::MetaTensor> meta_outs;
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(dense_outs[i]);
}
pten::SplitInferMeta(
MakeMetaTensor(*dense_x), num_or_sections, axis, &meta_outs);
using kernel_signature = void (*)(const platform::DeviceContext&,
const pten::DenseTensor&,
const pten::ScalarArray&,
const pten::Scalar&,
std::vector<pten::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*dense_x,
pten::ScalarArray(num_or_sections),
pten::Scalar(axis),
dense_outs);
return out;
}
} // namespace experimental
} // namespace paddle
......
......@@ -36,45 +36,6 @@ std::unique_ptr<pten::DenseTensor> MakePtenDenseTensor(
return std::make_unique<pten::DenseTensor>(src);
}
pten::Scalar MakePtenScalar(const paddle::framework::Tensor& src) {
PADDLE_ENFORCE_EQ(src.numel(),
1,
paddle::platform::errors::InvalidArgument(
"The Scalar only supports Tensor with 1 element, "
"but now Tensor has %d element.",
src.numel()));
switch (src.type()) {
case paddle::framework::proto::VarType::FP32:
return {src.template data<float>()[0]};
case paddle::framework::proto::VarType::FP64:
return {src.template data<double>()[0]};
case paddle::framework::proto::VarType::FP16:
return {src.template data<float16>()[0]};
case paddle::framework::proto::VarType::BF16:
return {src.template data<bfloat16>()[0]};
case paddle::framework::proto::VarType::INT32:
return {src.template data<int32_t>()[0]};
case paddle::framework::proto::VarType::INT64:
return {src.template data<int64_t>()[0]};
case paddle::framework::proto::VarType::INT16:
return {src.template data<int16_t>()[0]};
case paddle::framework::proto::VarType::INT8:
return {src.template data<int8_t>()[0]};
case paddle::framework::proto::VarType::UINT8:
return {src.template data<uint8_t>()[0]};
case paddle::framework::proto::VarType::BOOL:
return {src.template data<bool>()[0]};
case paddle::framework::proto::VarType::COMPLEX64:
return {src.template data<complex64>()[0]};
case paddle::framework::proto::VarType::COMPLEX128:
return {src.template data<complex128>()[0]};
default:
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Data type error. Don't support casting a %d LoDTensor to Scalar.",
src.type()));
}
}
pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) {
auto expected_place = pten::TransToPtenPlace(pten::Backend::CPU);
if (variable.IsType<framework::LoDTensor>()) {
......@@ -82,9 +43,9 @@ pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) {
if (!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
return MakePtenScalar(tmp_tensor);
return {tmp_tensor};
} else {
return MakePtenScalar(tensor);
return {tensor};
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
......@@ -95,17 +56,7 @@ pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) {
}
pten::ScalarArray MakePtenScalarArray(const paddle::framework::Tensor& src) {
if (src.type() == paddle::framework::proto::VarType::INT64) {
return {src.data<int64_t>(), src.numel()};
} else if (src.type() == paddle::framework::proto::VarType::INT32) {
return {src.data<int32_t>(), src.numel()};
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Data type error. When cast a LoDTensor to ScalarArray, "
"the data type of LoDTensor must be int32 or int64, "
"but now data type is %s.",
src.type()));
}
return {src};
}
pten::ScalarArray MakePtenScalarArrayFromVar(
......@@ -128,6 +79,7 @@ pten::ScalarArray MakePtenScalarArrayFromVar(
}
}
// TODO(chentianyu03): Inplace with ScalarArray constructor
pten::ScalarArray MakePtenScalarArrayFromVarList(
const std::vector<framework::Variable*>& variable_list) {
if (variable_list.size() == 0) {
......@@ -135,45 +87,28 @@ pten::ScalarArray MakePtenScalarArrayFromVarList(
}
auto expected_place = pten::TransToPtenPlace(pten::Backend::CPU);
paddle::framework::proto::VarType::Type data_type;
auto* first_var = variable_list.front();
if (first_var->IsType<framework::LoDTensor>()) {
const auto& tensor = first_var->Get<framework::LoDTensor>();
data_type = tensor.type();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(first_var->Type())));
}
std::vector<int64_t> vector_data;
vector_data.reserve(variable_list.size());
if (data_type == paddle::framework::proto::VarType::INT64) {
for (auto* var : variable_list) {
paddle::framework::proto::VarType::Type data_type;
if (var->IsType<framework::LoDTensor>()) {
const auto& tensor = var->Get<framework::LoDTensor>();
if (!platform::is_same_place(tensor.place(), expected_place)) {
data_type = tensor.type();
if (data_type == paddle::framework::proto::VarType::INT64) {
const auto& tensor = var->Get<framework::LoDTensor>();
if (tensor.IsInitialized() &&
!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
vector_data.push_back(*tmp_tensor.data<int64_t>());
} else {
vector_data.push_back(*tensor.data<int64_t>());
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(var->Type())));
}
}
} else if (data_type == paddle::framework::proto::VarType::INT32) {
for (auto* var : variable_list) {
if (var->IsType<framework::LoDTensor>()) {
const auto& tensor = var->Get<framework::LoDTensor>();
if (!platform::is_same_place(tensor.place(), expected_place)) {
if (tensor.IsInitialized() &&
!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
vector_data.push_back(*tmp_tensor.data<int32_t>());
......@@ -181,21 +116,24 @@ pten::ScalarArray MakePtenScalarArrayFromVarList(
vector_data.push_back(*tensor.data<int32_t>());
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(var->Type())));
}
}
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
PADDLE_THROW(pten::errors::InvalidArgument(
"Data type error. When cast a LoDTensor to VectorTensor, "
"the data type of LoDTensor must be int32 or int64, "
"but now data type is %s.",
data_type));
}
} else {
PADDLE_THROW(pten::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(var->Type())));
}
}
pten::ScalarArray result{vector_data};
result.setInitByTensor(true);
return {vector_data};
return result;
}
void ResetTensorDtypeAndLayoutByArgDef(pten::TensorBase* dst,
......
......@@ -33,8 +33,6 @@ namespace experimental {
std::unique_ptr<pten::DenseTensor> MakePtenDenseTensor(
const paddle::framework::Tensor& src);
pten::Scalar MakePtenScalar(const paddle::framework::Tensor& src);
pten::ScalarArray MakePtenScalarArray(const paddle::framework::Tensor& src);
pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable);
......
......@@ -25,6 +25,7 @@ namespace experimental {
template <typename T>
class ScalarBase {
public:
bool IsInitByTensor() const { return is_init_by_tensor_; }
// Constructor support implicit
ScalarBase(double val) : dtype_(DataType::FLOAT64) { // NOLINT
data_.f64 = val;
......@@ -103,6 +104,7 @@ class ScalarBase {
// The Tensor must have one dim
ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT
is_init_by_tensor_ = true;
PD_CHECK(
tensor.numel() == 1,
"The Scalar only supports Tensor with 1 element, but now Tensor has `",
......@@ -194,6 +196,7 @@ class ScalarBase {
friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);
private:
bool is_init_by_tensor_{false};
DataType dtype_;
union data {
bool b;
......
......@@ -43,8 +43,13 @@ class ScalarArrayBase {
AssignData(date_value, n);
}
bool IsInitByTensor() const { return is_init_by_tensor_; }
void setInitByTensor(bool val) { is_init_by_tensor_ = val; }
// The Tensor must have one dim
ScalarArrayBase(const T& tensor) { // NOLINT
is_init_by_tensor_ = true;
size_t n = tensor.numel();
array_.reserve(n);
switch (tensor.dtype()) {
......@@ -66,41 +71,17 @@ class ScalarArrayBase {
// The Tensor in vec must have only one element
ScalarArrayBase(const std::vector<T>& tensor_list) { // NOLINT
auto n = tensor_list.size();
array_.reserve(n);
if (!tensor_list.empty()) {
DataType data_type = tensor_list[0].dtype();
is_init_by_tensor_ = true;
for (size_t i = 0; i < tensor_list.size(); ++i) {
DataType data_type = tensor_list[i].dtype();
switch (data_type) {
case DataType::INT32: {
for (size_t i = 0; i < n; ++i) {
PD_CHECK(tensor_list[i].dtype() == data_type,
"The data_type of tensors in the list isn't consistent."
"the first tensor is`",
data_type,
"` but `",
i,
"`th tensor is`",
tensor_list[i].dtype(),
"`.");
case DataType::INT32:
array_.push_back(*tensor_list[i].template data<int32_t>());
}
break;
}
case DataType::INT64: {
for (size_t i = 0; i < n; ++i) {
PD_CHECK(tensor_list[i].dtype() == data_type,
"The data_type of tensors in the list isn't consistent."
"the first tensor is`",
data_type,
"` but `",
i,
"`th tensor is`",
tensor_list[i].dtype(),
"`.");
case DataType::INT64:
array_.push_back(*tensor_list[i].template data<int64_t>());
}
break;
}
default:
PD_THROW(
"Data type error. Currently, The data type of ScalarArrayBase "
......@@ -136,6 +117,7 @@ class ScalarArrayBase {
// TODO(zhangyunfei) Replace std::vector with a more efficient container
// structure.
std::vector<int64_t> array_;
bool is_init_by_tensor_{false};
};
using ScalarArray =
......
......@@ -77,6 +77,7 @@ class ArgumentMappingContext {
virtual bool HasInput(const std::string& name) const = 0;
virtual bool HasOutput(const std::string& name) const = 0;
virtual bool HasAttr(const std::string& name) const = 0;
// now we can't use Attribute here, it will cause pten relay on
// boost::variant and BlockDesc
......
......@@ -146,6 +146,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
}
};
// TODO(chenweihang): support other attr type later
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
......
......@@ -23,8 +23,12 @@ void MatmulGradInferMeta(const MetaTensor& x,
bool transpose_y,
MetaTensor* dx,
MetaTensor* dy) {
if (dx) {
dx->share_meta(x);
}
if (dy) {
dy->share_meta(y);
}
}
} // namespace pten
......@@ -315,4 +315,137 @@ void TransferLayoutInferMeta(const MetaTensor& x,
out->set_layout(layout);
}
void SplitInferMeta(const MetaTensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis,
std::vector<MetaTensor>* out,
MetaConfig config) {
int axis_value = axis.to<int>();
int rank = x.dims().size();
PADDLE_ENFORCE_EQ(
axis_value >= -rank && axis_value < rank,
true,
paddle::platform::errors::InvalidArgument(
"The axis is expected to be in range of [%d, %d), but got %d",
-rank,
rank,
axis_value));
if (axis_value < 0) {
axis_value = axis_value + rank;
}
auto input_axis_dim = x.dims().at(axis_value);
auto num_or_sections_data = num_or_sections.GetData();
// step1: get formated sections
std::vector<int64_t> sections;
// num_or_sections is a number
if (num_or_sections_data.size() == 1) {
int num = num_or_sections_data.at(0);
PADDLE_ENFORCE_EQ(input_axis_dim % num,
0,
paddle::platform::errors::InvalidArgument(
"The input's size along the split dimension "
"must be evenly divisible by Attr(num_or_sections). "
"But received Attr(num_or_sections) "
"= %d, input(X)'s shape = [%s], Attr(dim) = %d.",
num,
x.dims(),
axis_value));
for (int i = 0; i < num; ++i) {
sections.push_back(input_axis_dim / num);
}
} else {
// num_or_sections is a sections
const int unknow_dim_val = -1;
int unknow_dim_idx = -1;
int num_of_unknow = 0;
int sum_of_section = 0;
for (size_t i = 0; i < num_or_sections_data.size(); ++i) {
sections.push_back(num_or_sections_data[i]);
if (num_or_sections_data[i] == unknow_dim_val) {
num_of_unknow++;
unknow_dim_idx = i;
} else {
sum_of_section += num_or_sections_data[i];
}
}
if (config.is_runtime) {
PADDLE_ENFORCE_LE(num_of_unknow,
1,
paddle::platform::errors::InvalidArgument(
"Only one dimension value of Attr(num_or_sections) "
"in SplitOp can be -1. "
"But received Attr(num_or_sections) = [%s].",
pten::framework::make_ddim(num_or_sections_data)));
}
if (unknow_dim_idx != -1) {
// for example, input shape = [4 ,5], axis = 1, sections = [2, 3, -1].
// input_axis_dim = 5, sum_of_sections = 5.
// the following check will fail.
PADDLE_ENFORCE_LT(
sum_of_section,
input_axis_dim,
paddle::platform::errors::InvalidArgument(
"Sum of Attr(num_or_sections) other than unknown section "
"must be less than the input's "
"size "
"along the split dimension. But received Attr(num_or_sections) "
"= [%s], input(X)'s shape = [%s], Attr(dim) = %d.",
pten::framework::make_ddim(num_or_sections_data),
x.dims(),
axis_value));
if (config.is_runtime) {
sections[unknow_dim_idx] = input_axis_dim - sum_of_section;
}
} else {
PADDLE_ENFORCE_EQ(
sum_of_section,
input_axis_dim,
paddle::platform::errors::InvalidArgument(
"Sum of Attr(num_or_sections) must be equal to the input's "
"size "
"along the split dimension. But received Attr(num_or_sections)"
" = [%s], input(X)'s shape = [%s], Attr(dim) = %d.",
pten::framework::make_ddim(num_or_sections_data),
x.dims(),
axis_value));
}
}
// setp2: fill out dims
std::vector<pten::DDim> out_dims(sections.size(), x.dims());
if (config.is_runtime || input_axis_dim > 0) {
for (size_t i = 0; i < sections.size(); ++i) {
out_dims[i][axis_value] = sections[i];
}
} else {
for (size_t i = 0; i < sections.size(); ++i) {
out_dims[i][axis_value] = -1;
}
}
for (size_t i = 0; i < sections.size(); ++i) {
if (axis_value != 0) {
// Only pass LoD when not spliting along the first dim.
(*out)[i].set_dtype(x.dtype());
(*out)[i].set_dims(out_dims[i]);
(*out)[i].set_layout(x.layout());
} else {
(*out)[i].set_dtype(x.dtype());
(*out)[i].set_dims(out_dims[i]);
(*out)[i].set_layout(x.layout());
(*out)[i].share_lod(x);
}
}
return;
}
} // namespace pten
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/meta_tensor.h"
......@@ -74,4 +75,9 @@ void TransferLayoutInferMeta(const MetaTensor& x,
DataLayout layout,
MetaTensor* out);
void SplitInferMeta(const MetaTensor& x_meta,
const ScalarArray& num_or_sections,
const Scalar& axis,
std::vector<MetaTensor>* out,
MetaConfig config = MetaConfig());
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/split_kernel.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/pten/common/float16.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/cpu/concat_and_split.h"
namespace pten {
template <typename T, typename Context>
void SplitKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis_scalar,
std::vector<DenseTensor*> outs) {
// need to infershape output
if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) {
std::vector<MetaTensor> out_metas;
for (size_t i = 0; i < outs.size(); ++i) {
out_metas.push_back(outs[i]);
}
pten::SplitInferMeta(x, num_or_sections, axis_scalar, &out_metas, true);
for (size_t i = 0; i < out_metas.size(); ++i) {
outs[i]->Resize(out_metas[i].dims());
}
}
std::vector<const DenseTensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
dev_ctx.Alloc(outs[j]);
shape_refer.emplace_back(outs[j]);
}
int axis = axis_scalar.to<int>();
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && outs.size() < 10) {
paddle::operators::StridedMemcpyWithAxis0<T>(
dev_ctx, x, shape_refer, &outs);
} else {
SplitImpl<T, Context>(dev_ctx, x, shape_refer, axis, &outs);
}
}
} // namespace pten
PT_REGISTER_KERNEL(split,
CPU,
ALL_LAYOUT,
pten::SplitKernel,
float,
double,
int64_t,
int,
bool,
pten::dtype::float16) {}
......@@ -134,7 +134,7 @@ __global__ void ConcatKernel_(const T** inputs_data,
}
template <typename T>
__global__ void SplitKernel(const T* input_data,
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t* out_cols,
......@@ -184,7 +184,7 @@ __device__ void SplitKernelDetail(const T* input_data,
}
template <typename T>
__global__ void SplitKernel(const T* input_data,
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
......@@ -193,7 +193,7 @@ __global__ void SplitKernel(const T* input_data,
}
template <typename T>
__global__ void SplitKernel(const T* input_data,
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
......@@ -206,7 +206,7 @@ __global__ void SplitKernel(const T* input_data,
}
template <typename T>
__global__ void SplitKernel(const T* input_data,
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
......@@ -221,7 +221,7 @@ __global__ void SplitKernel(const T* input_data,
}
template <typename T>
__global__ void SplitKernel(const T* input_data,
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
......@@ -497,7 +497,7 @@ void SplitImpl(const Context& context,
if (has_same_shape) {
if (o_num == 2) {
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(),
in_row,
in_col,
......@@ -505,7 +505,7 @@ void SplitImpl(const Context& context,
outputs_data[0],
outputs_data[1]);
} else if (o_num == 3) {
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(),
in_row,
in_col,
......@@ -514,7 +514,7 @@ void SplitImpl(const Context& context,
outputs_data[1],
outputs_data[2]);
} else if (o_num == 4) {
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(),
in_row,
in_col,
......@@ -524,7 +524,7 @@ void SplitImpl(const Context& context,
outputs_data[2],
outputs_data[3]);
} else {
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
}
} else {
......@@ -542,7 +542,7 @@ void SplitImpl(const Context& context,
int64_t* dev_outs_col_data =
reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(),
in_row,
in_col,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/split_kernel.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/pten/common/float16.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/gpu/concat_and_split.h"
namespace pten {
template <typename T, typename Context>
void SplitKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis_scalar,
std::vector<DenseTensor*> outs) {
// need to infershape output
if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) {
std::vector<MetaTensor> out_metas;
for (size_t i = 0; i < outs.size(); ++i) {
out_metas.push_back(outs[i]);
}
pten::SplitInferMeta(x, num_or_sections, axis_scalar, &out_metas, true);
for (size_t i = 0; i < out_metas.size(); ++i) {
outs[i]->Resize(out_metas[i].dims());
}
}
std::vector<const DenseTensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
dev_ctx.Alloc(outs[j]);
shape_refer.emplace_back(outs[j]);
}
int axis = axis_scalar.to<int>();
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && outs.size() < 10) {
paddle::operators::StridedMemcpyWithAxis0<T>(
dev_ctx, x, shape_refer, &outs);
} else {
SplitImpl<T, Context>(dev_ctx, x, shape_refer, axis, &outs);
}
}
} // namespace pten
PT_REGISTER_KERNEL(split,
GPU,
ALL_LAYOUT,
pten::SplitKernel,
float,
double,
int64_t,
int,
bool,
pten::dtype::float16,
pten::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
template <typename T, typename Context>
void SplitKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis,
std::vector<DenseTensor*> out);
template <typename T, typename Context>
std::vector<DenseTensor> Split(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis) {
size_t out_number;
if (num_or_sections.GetData().size() == 1) {
out_number = num_or_sections.GetData()[0];
} else {
out_number = num_or_sections.GetData().size();
}
std::vector<MetaTensor> out_meta;
out_meta.reserve(out_number);
std::vector<DenseTensor> result;
result.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
auto dense_out = pten::Empty<T, Context>(dev_ctx);
MetaTensor tmp_meta(&dense_out);
result.push_back(dense_out);
out_meta.push_back(&result.back());
}
SplitInferMeta(x, num_or_sections, axis, &out_meta);
std::vector<DenseTensor*> outs;
outs.reserve(out_meta.size());
for (size_t i = 0; i < out_meta.size(); ++i) {
outs.push_back(&result[i]);
}
SplitKernel<T, Context>(dev_ctx, x, num_or_sections, axis, outs);
return result;
}
} // namespace pten
......@@ -17,10 +17,17 @@ limitations under the License. */
namespace pten {
KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasAttr("use_addto")) {
return KernelSignature("addto_matmul_grad",
{"X", "Y", GradVarName("Out")},
{"trans_x", "trans_y", "use_addto"},
{GradVarName("X"), GradVarName("Y")});
} else {
return KernelSignature("matmul_grad",
{"X", "Y", GradVarName("Out")},
{"trans_x", "trans_y"},
{GradVarName("X"), GradVarName("Y")});
}
}
KernelSignature MatmulDoubleGradOpArgumentMapping(
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature SplitOpArgumentMapping(const ArgumentMappingContext& ctx) {
// priority: num > SectionsTensorList > sections
// priority: AxisTensor > axis
if (paddle::any_cast<int>(ctx.Attr("num")) > 0) {
if (ctx.HasInput("AxisTensor")) {
return KernelSignature("split", {"X"}, {"num", "AxisTensor"}, {"Out"});
} else {
return KernelSignature("split", {"X"}, {"num", "axis"}, {"Out"});
}
}
if (ctx.InputSize("SectionsTensorList") > 0) {
if (ctx.HasInput("AxisTensor")) {
return KernelSignature(
"split", {"X"}, {"SectionsTensorList", "AxisTensor"}, {"Out"});
} else {
return KernelSignature(
"split", {"X"}, {"SectionsTensorList", "axis"}, {"Out"});
}
}
if (ctx.HasInput("AxisTensor")) {
return KernelSignature("split", {"X"}, {"sections", "AxisTensor"}, {"Out"});
} else {
return KernelSignature("split", {"X"}, {"sections", "axis"}, {"Out"});
}
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(split, pten::SplitOpArgumentMapping);
......@@ -22,6 +22,6 @@ cc_test(test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api
cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_conj_api SRCS test_conj_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_concat_api SRCS test_concat_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_split_api SRCS test_split_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_data_transform SRCS test_data_transform.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_sparse_utils_api SRCS test_sparse_utils_api.cc DEPS pten_tensor pten_api pten_api_utils)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/api/include/api.h"
#include "paddle/pten/api/include/manual_api.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace paddle {
namespace tests {
namespace framework = paddle::framework;
using DDim = pten::framework::DDim;
// TODO(chentianyu03): Remove this test after the API is used in the dygraph
TEST(API, split) {
// 1. create tensor
const auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
alloc.get(),
pten::DenseTensorMeta(pten::DataType::FLOAT32,
pten::framework::make_ddim({4, 10}),
pten::DataLayout::NCHW));
auto* dense_x_data =
dense_x->mutable_data<float>(paddle::platform::CPUPlace());
for (size_t i = 0; i < 4; ++i) {
for (size_t j = 0; j < 10; ++j) {
dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0;
}
}
paddle::experimental::Tensor x(dense_x);
// 2. test API
auto out = paddle::experimental::split(x, {2, 2}, 0);
// 3. check result
ASSERT_EQ(out.size(), static_cast<size_t>(2));
ASSERT_EQ(out[0].dims().size(), 2);
ASSERT_EQ(out[0].dims()[0], 2);
ASSERT_EQ(out[0].dims()[1], 10);
ASSERT_EQ(out[0].type(), pten::DataType::FLOAT32);
ASSERT_EQ(out[0].layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out[1].dims().size(), 2);
ASSERT_EQ(out[1].dims()[0], 2);
ASSERT_EQ(out[1].dims()[1], 10);
ASSERT_EQ(out[1].type(), pten::DataType::FLOAT32);
ASSERT_EQ(out[1].layout(), pten::DataLayout::NCHW);
auto out_data_0 = std::dynamic_pointer_cast<pten::DenseTensor>(out[0].impl())
->data<float>();
auto out_data_1 = std::dynamic_pointer_cast<pten::DenseTensor>(out[1].impl())
->data<float>();
for (size_t i = 0; i < 4; ++i) {
if (i < 20) {
ASSERT_NEAR(dense_x_data[i], out_data_0[i], 1e-6);
} else {
ASSERT_NEAR(dense_x_data[i], out_data_1[i - 20], 1e-6);
}
}
}
} // namespace tests
} // namespace paddle
......@@ -11,4 +11,5 @@ cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_uti
cc_test(test_sum_dev_api SRCS test_sum_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_conj_dev_api SRCS test_conj_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_concat_dev_api SRCS test_concat_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_split_dev_api SRCS test_split_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS pten pten_api_utils)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/kernels/split_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/include/manual_api.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace pten {
namespace tests {
namespace framework = paddle::framework;
using DDim = pten::framework::DDim;
TEST(DEV_API, split) {
// 1. create tensor
const auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>(
pten::CPUPlace());
pten::DenseTensor dense_x(
alloc.get(),
pten::DenseTensorMeta(pten::DataType::FLOAT32,
pten::framework::make_ddim({4, 10}),
pten::DataLayout::NCHW));
pten::CPUContext dev_ctx;
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.Init();
auto* dense_x_data = dev_ctx.Alloc<float>(&dense_x);
for (size_t i = 0; i < 4; ++i) {
for (size_t j = 0; j < 10; ++j) {
dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0;
}
}
// 2. test API
auto out = pten::Split<float>(dev_ctx, dense_x, {2, 2}, 0);
// 3. check result
ASSERT_EQ(out.size(), static_cast<size_t>(2));
ASSERT_EQ(out[0].dims().size(), 2);
ASSERT_EQ(out[0].dims()[0], 2);
ASSERT_EQ(out[0].dims()[1], 10);
ASSERT_EQ(out[0].meta().dtype, pten::DataType::FLOAT32);
ASSERT_EQ(out[0].meta().layout, pten::DataLayout::NCHW);
ASSERT_EQ(out[1].dims().size(), 2);
ASSERT_EQ(out[1].dims()[0], 2);
ASSERT_EQ(out[1].dims()[1], 10);
ASSERT_EQ(out[1].meta().dtype, pten::DataType::FLOAT32);
ASSERT_EQ(out[1].meta().layout, pten::DataLayout::NCHW);
auto out_data_0 = out[0].data<float>();
auto out_data_1 = out[1].data<float>();
for (size_t i = 0; i < 4; ++i) {
if (i < 20) {
ASSERT_NEAR(dense_x_data[i], out_data_0[i], 1e-6);
} else {
ASSERT_NEAR(dense_x_data[i], out_data_1[i - 20], 1e-6);
}
}
}
} // namespace tests
} // namespace pten
......@@ -1759,11 +1759,11 @@ set +x
set -x
ut_endTime_s=`date +%s`
echo "XPU testCase Time: $[ $ut_endTime_s - $ut_startTime_s ]s"
python ${PADDLE_ROOT}/build/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py
unset XPU_OP_LIST_DIR
if [[ "$EXIT_CODE" != "0" ]]; then
exit 8;
fi
python ${PADDLE_ROOT}/build/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py
unset XPU_OP_LIST_DIR
fi
}
......
......@@ -21,25 +21,6 @@ from .pass_base import PassBase, register_pass
from paddle.fluid.transpiler.details.program_utils import delete_ops
from paddle.fluid.transpiler.collective import SingleProcessMultiThread
OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "gradient_clip"
STEP_COUNTER = "@PS_STEP_COUNTER@"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
SPARSE_GRAD_OP_TYPE_DICT = {
"lookup_table_grad": "W",
"lookup_table_v2_grad": "W"
}
DEVICE_LIST = ["cpu", "gpu", "xpu"]
COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"]
DEFAULT_DEVICE = 'cpu'
@register_pass("append_send_ops_pass")
class AppendSendOpsPass(PassBase): # 该 pass 被多种模式复用
......@@ -894,6 +875,100 @@ class SplitTrainerOpsPass(PassBase):
def _check_conflict(self, other_pass):
return True
def _replace_ops_by_communicate_op(self, program, attrs, heter_block_index,
ops_list, block_var_detail):
all_op = program.global_block().ops
start_op = ops_list[0]
first_op_idx = -1
for op in all_op:
if str(op) == str(start_op):
first_op_idx = all_op.index(op)
break
assert first_op_idx != -1
self._delete_same_ops(program.global_block(), ops_list)
entrance_var = []
role_maker = attrs['role_maker']
if heter_block_index == 1:
next_heter_worker_endpoints = get_next_stage_trainers(role_maker)
entrance_var = block_var_detail[heter_block_index]["forward"][
"entrance"]
comm_info = get_communicate_var_info(program, heter_block_index + 1,
entrance_var)
program.global_block()._insert_op(
index=first_op_idx,
type="send_and_recv",
inputs={"X": program.global_block().vars[entrance_var[0]]},
outputs={"Out": []},
attrs={
"mode": "forward",
"send_var_name": entrance_var + ["microbatch_id"],
"recv_var_name": [],
"message_name": comm_info["block_input_var_name"],
"next_endpoints": next_heter_worker_endpoints,
"previous_endpoints": [],
"trainer_id": get_role_id(role_maker),
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
return entrance_var
def _delete_same_ops(self, block, ops):
for op in ops:
try:
for origin_op in block.ops:
if str(origin_op) == str(op):
idx = list(block.ops).index(origin_op)
block._remove_op(idx)
break
except Exception as e:
print(e)
def _remove_var_pair_by_grad(self, var_name, attrs):
for index, pair in enumerate(attrs['merged_variables_pairs']):
var = pair[0]
var_grad = pair[1]
if var_grad.merged_var.name == var_name:
del attrs['merged_variables_pairs'][index]
for index, pair in enumerate(attrs['merged_dense_pairs']):
var = pair[0]
var_grad = pair[1]
if var_grad.merged_var.name == var_name:
del attrs['merged_dense_pairs'][index]
return
for index, pair in enumerate(attrs['merged_sparse_pairs']):
var = pair[0]
var_grad = pair[1]
if var_grad.merged_var.name == var_name:
del attrs['merged_sparse_pairs'][index]
return
def _remove_trainer_send_op(self, program, attrs, heter_block_index,
block_var_detail):
# if trainer do FF->BP->SEND, it has follow vars: var, var@GRAD
# if trainer only do SEND, it has one var: var@GRAD
# Delete Send op ,if trainer doesn't has pair var (var<->var@GRAD)
persistables = block_var_detail[heter_block_index]["forward"]["persistables"] + \
block_var_detail[heter_block_index]["backward"]["persistables"]
need_remove_send_op = []
need_remove_grad_var = []
for op in find_send_op(program):
input_list, _ = find_op_input_output(program,
program.global_block(), op)
for var_name in input_list:
origin_var_name = var_name.split("@GRAD")[0]
if origin_var_name in persistables:
need_remove_send_op.append(op)
need_remove_grad_var.append(var_name)
need_remove_send_op = list(set(need_remove_send_op))
delete_ops(program.global_block(), need_remove_send_op)
for grad_var_name in need_remove_grad_var:
self._remove_var_pair_by_grad(grad_var_name, attrs)
def _create_trainer_program(self, program, origin_program, attrs,
program_block_ops_list, block_var_detail):
# This function mainly includes the following contents:
......@@ -911,18 +986,18 @@ class SplitTrainerOpsPass(PassBase):
ops_list = program_block_ops_list[heter_block_index][
"forward"] + program_block_ops_list[heter_block_index][
"backward"]
static_var += replace_ops_by_communicate_op(
static_var += self._replace_ops_by_communicate_op(
program, attrs, heter_block_index, ops_list, block_var_detail)
remove_trainer_send_op(program, attrs, heter_block_index,
self._remove_trainer_send_op(program, attrs, heter_block_index,
block_var_detail)
optimizer_block = []
grad_to_block_id = []
bp_ops_list = program_block_ops_list[0]["backward"]
delete_same_ops(program.global_block(), bp_ops_list)
delete_trainer_useless_var(attrs, program, static_var)
backward_block = create_backward_block(program, origin_program, attrs,
self._delete_same_ops(program.global_block(), bp_ops_list)
delete_trainer_useless_var(program, static_var)
backward_block = create_backward_block(program, origin_program,
bp_ops_list, block_var_detail)
bp_entrance_vars = block_var_detail[0]["backward"]["entrance"]
......
......@@ -186,10 +186,10 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder):
add_lr_decay_table_pass.apply([], [], self.pass_ctx)
distributed_ops_pass = new_pass("distributed_ops_pass", self.attrs)
distributed_ops_pass.apply([self.cloned_main], [], self.pass_ctx)
distributed_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
delete_optimizer_pass = new_pass("delete_optimizer_pass", self.attrs)
delete_optimizer_pass.apply([None], [_startup], self.pass_ctx)
delete_optimizer_pass.apply([self.cloned_main], [None], self.pass_ctx)
append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs)
append_send_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
......@@ -210,12 +210,13 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder):
else:
split_trainer_ops_pass = new_pass("split_trainer_ops_pass",
self.attrs)
split_trainer_ops_pass([self.cloned_main], [], self.pass_ctx)
split_trainer_ops_pass.apply([self.cloned_main], [None],
self.pass_ctx)
set_heter_pipeline_opt_pass = new_pass('set_heter_pipeline_opt_pass',
self.attrs)
set_heter_pipeline_opt_pass.apply([self.cloned_main],
[self.cloned_startup], pass_ctx)
[self.cloned_startup], self.pass_ctx)
if self.launch_barrier and self.launch_barrier_flag:
wait_server_ready(server_endpoints)
......@@ -228,7 +229,7 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder):
ps_set_heter_pipeline_opt_pass = new_pass(
"set_heter_pipeline_opt_pass", self.attrs)
ps_set_heter_pipeline_opt_pass.apply(
[self.loss.block.program], [startup_program], self.pass_ctx)
[self.cloned_main], [self.cloned_startup], self.pass_ctx)
elif self.attrs['is_server']:
self._build_pserver_programs()
......
......@@ -42,9 +42,17 @@ RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
backward = core.op_proto_and_checker_maker.OpRole.Backward
DEVICE_LIST = ["cpu", "gpu", "xpu"]
COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"]
SPARSE_OP_LIST = ["lookup_table", "lookup_table_v2"]
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
SPARSE_GRAD_OP_TYPE_DICT = {
"lookup_table_grad": "W",
"lookup_table_v2_grad": "W"
}
DEFAULT_DEVICE = 'cpu'
def logger_config(log_path, logging_name):
......@@ -640,6 +648,20 @@ def find_block_joints(program, program_block_ops_list, heter_ops):
return block_var_detail
def find_ops_list_input_output(program, ops_list):
input_var_list = []
output_var_list = []
for op in ops_list:
inputs = _get_input_map_from_op(program.global_block().vars, op)
input_var_list += get_varlist_from_op_map(inputs)
outputs = _get_output_map_from_op(program.global_block().vars, op)
output_var_list += get_varlist_from_op_map(outputs)
input_var_list = list(set(input_var_list))
output_var_list = list(set(output_var_list))
return input_var_list, output_var_list
def find_entrance_exit_private(program, program_block_ops_list):
block_var_detail = []
persistables = []
......@@ -850,6 +872,54 @@ def _get_output_map_from_op(varmap, op):
return iomap
def get_varlist_from_op_map(var_map):
var_list = []
for key, varlist in six.iteritems(var_map):
if not isinstance(varlist, list):
varlist = [varlist]
for i in range(len(varlist)):
var = varlist[i]
var_list.append(var.name)
return var_list
def _get_input_map_from_op(varmap, op):
"""Returns a dict from op input name to the vars in varmap."""
iomap = collections.OrderedDict()
for key in op.input_names:
vars = []
for varname in op.input(key):
if varname == "@EMPTY@":
continue
if "lod_tensor_blocking_queue" in varname:
continue
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap
def screen_persistables(program, var_list):
need_remove = []
for var_name in var_list:
if "@GRAD" in var_name:
if "GRAD" != var_name.split("@")[-1]:
continue
origin_var_name = var_name.split("@GRAD")[0]
var = program.global_block().vars[origin_var_name]
else:
var = program.global_block().vars[var_name]
if fluid.io.is_persistable(var):
need_remove.append(var_name)
for var_name in need_remove:
var_list.remove(var_name)
return need_remove
def block_append_op(program, origin_program, block, op):
merge_ordereddict = origin_program.global_block().vars.copy()
merge_ordereddict.update(block.vars)
......@@ -1154,6 +1224,84 @@ def get_param_grads(origin_program):
return sparse_param_grads, dense_param_grads
def delete_ops(block, ops):
for op in ops:
try:
idx = list(block.ops).index(op)
block._remove_op(idx)
except Exception as e:
print(e)
def find_send_op(program):
send_op_list = []
for op in program.global_block().ops:
if op.type == "send":
send_op_list.append(op)
return send_op_list
def find_op_input_output(program, block, op):
input_var_list = []
output_var_list = []
inputs = _get_input_map_from_op(block.vars, op)
input_var_list += get_varlist_from_op_map(inputs)
outputs = _get_output_map_from_op(block.vars, op)
output_var_list += get_varlist_from_op_map(outputs)
input_var_list = list(set(input_var_list))
output_var_list = list(set(output_var_list))
return input_var_list, output_var_list
def get_vars_name_in_block(block):
vars_list = block.vars.keys()
vars_name_list = [var_name for var_name in vars_list]
return vars_name_list
def delete_trainer_useless_var(program, static_var):
static_var = list(set(static_var))
program_useful_var_list = []
for op in program.global_block().ops:
input_var_list, output_var_list = find_op_input_output(
program, program.global_block(), op)
op_var_list = list(set(input_var_list).union(set(output_var_list)))
program_useful_var_list = list(
set(program_useful_var_list).union(set(op_var_list)))
program_useful_var_list += static_var
program_useless_var_list = list(
set(get_vars_name_in_block(program.global_block())).difference(
set(program_useful_var_list)))
for var in program_useless_var_list:
program.global_block()._remove_var(var)
return program_useless_var_list
def create_backward_block(program, origin_program, bp_ops_list,
block_var_detail):
pre_block_idx = program.num_blocks - 1
heter_block = program._create_block(pre_block_idx)
for _, op in enumerate(bp_ops_list):
if op.type == "send":
send_varnames = op.attr('send_varnames')
is_skip = False
for varname in send_varnames:
if varname not in program.global_block(
).vars and varname not in heter_block.vars:
is_skip = True
break
if is_skip == True:
continue
block_append_op(program, origin_program, heter_block, op)
entrance_vars = block_var_detail[0]["backward"]["entrance"]
add_vars_by_var_list(entrance_vars, origin_program, program, heter_block)
exit_vars = block_var_detail[0]["backward"]["exit"]
add_vars_by_var_list(exit_vars, origin_program, program, heter_block)
return heter_block
def debug_program(file, program, is_trainer):
if is_trainer:
with open(file, 'w+') as f:
......
......@@ -21,6 +21,17 @@ from paddle.fluid import core
from paddle.fluid import framework
from paddle import _C_ops
final_state_name_mapping = {
"matmul_v2": {
"final_op_name": "final_state_matmul",
"transpose_x": "trans_x",
"transpose_y": "trans_y",
"x": "X",
"y": "Y",
"out": "Out",
}
}
class Tracer(core.Tracer):
"""
......@@ -40,17 +51,13 @@ class Tracer(core.Tracer):
self._train_mode = True
def trace_op(self,
def eager_trace_op(self,
type,
inputs,
outputs,
attrs,
stop_gradient=False,
inplace_map=None):
if framework._in_eager_mode():
# inputs : {"sum": [tensor], ...}
# outputs : {"sum": [tensor], ...}
function_ptr = _C_ops.__dict__[type]
core_ops_args_info = _C_ops.get_core_ops_args_info()
......@@ -107,22 +114,127 @@ class Tracer(core.Tracer):
# Replaced outputs by function returns
if isinstance(returns[i], list):
for j in range(len(returns[i])):
outputs[retname][j].reconstruct_from_(returns[i]
[j])
outputs[retname][j].reconstruct_from_(returns[i][j],
False)
else:
outputs[retname][0].reconstruct_from_(returns[i])
outputs[retname][0].reconstruct_from_(returns[i], False)
elif isinstance(returns, list):
assert len(outputs.keys()) == 1
key = list(outputs.keys())[0]
for j in range(len(returns)):
outputs[key][j].reconstruct_from_(returns[j])
outputs[key][j].reconstruct_from_(returns[j], False)
else:
assert len(outputs.keys()) == 1
key = list(outputs.keys())[0]
if isinstance(outputs[key], list):
outputs[key][0].reconstruct_from_(returns)
outputs[key][0].reconstruct_from_(returns, False)
else:
outputs[key].reconstruct_from_(returns, False)
def eager_final_state_trace_op(self,
type,
inputs,
outputs,
attrs,
stop_gradient=False,
inplace_map=None):
assert type in final_state_name_mapping.keys()
final_state_type = final_state_name_mapping[type]["final_op_name"]
function_ptr = _C_ops.__dict__[final_state_type]
core_ops_args_info = _C_ops.get_final_state_core_ops_args_info()
core_ops_args_type_info = _C_ops.get_final_state_core_ops_args_type_info(
)
core_ops_returns_info = _C_ops.get_final_state_core_ops_returns_info()
op_args = core_ops_args_info[final_state_type]
op_args_type = core_ops_args_type_info[final_state_type]
op_returns = core_ops_returns_info[final_state_type]
arg_list = []
for i in range(len(op_args)):
eager_arg_name = op_args[i]
arg_type = op_args_type[i]
assert eager_arg_name in final_state_name_mapping[type].keys()
arg_name = final_state_name_mapping[type][eager_arg_name]
if arg_name in inputs.keys():
arg_to_append = inputs[arg_name]
elif arg_name in outputs.keys():
arg_to_append = outputs[arg_name]
elif arg_name in attrs.keys() and arg_type == "":
arg_to_append = attrs[arg_name]
else:
# dispensable
arg_to_append = None
if arg_type == "":
# attribute
arg_list.append(arg_to_append)
elif arg_type == "tensor":
if isinstance(arg_to_append, list):
arg_list.append(arg_to_append[0])
else:
arg_list.append(arg_to_append)
elif arg_type == "list":
assert isinstance(arg_to_append, list)
arg_list.append(arg_to_append)
else:
assert arg_to_append is None
arg_list.append(arg_to_append)
returns = function_ptr(*arg_list)
if isinstance(returns, tuple):
for i in range(len(op_returns)):
eager_retname = op_returns[i]
assert eager_retname in final_state_name_mapping[type].keys()
retname = final_state_name_mapping[type][eager_retname]
if retname in outputs.keys():
# Replaced outputs by function returns
if isinstance(returns[i], list):
for j in range(len(returns[i])):
outputs[retname][j].reconstruct_from_(returns[i][j],
False)
else:
outputs[retname][0].reconstruct_from_(returns[i], False)
elif isinstance(returns, list):
assert len(outputs.keys()) == 1
key = list(outputs.keys())[0]
for j in range(len(returns)):
outputs[key][j].reconstruct_from_(returns[j], False)
else:
assert len(outputs.keys()) == 1
key = list(outputs.keys())[0]
if isinstance(outputs[key], list):
outputs[key][0].reconstruct_from_(returns, False)
else:
outputs[key].reconstruct_from_(returns, False)
def trace_op(self,
type,
inputs,
outputs,
attrs,
stop_gradient=False,
inplace_map=None):
if framework._in_eager_mode():
# inputs : {"sum": [tensor], ...}
# outputs : {"sum": [tensor], ...}
if type in final_state_name_mapping.keys():
final_state_type = final_state_name_mapping[type][
"final_op_name"]
assert final_state_type in _C_ops.__dict__
self.eager_final_state_trace_op(type, inputs, outputs, attrs,
stop_gradient, inplace_map)
else:
outputs[key].reconstruct_from_(returns)
self.eager_trace_op(type, inputs, outputs, attrs, stop_gradient,
inplace_map)
else:
self.trace(type, inputs, outputs, attrs,
framework._current_expected_place(), self._has_grad and
......
......@@ -22,6 +22,7 @@ import inspect
import unittest
import numpy as np
from collections import OrderedDict
from paddle.distributed.ps.utils.public import logger
from dist_pass_test_base import prepare_python_path_and_return_module, remove_path_if_exists
import paddle.distributed.fleet as fleet
......@@ -37,7 +38,7 @@ class PsPassTestBase(unittest.TestCase):
print('Ps tearDown...')
def ps_launch(self, config, ps_mode="cpu-ps"):
if ps_mode == "cpu-ps":
if ps_mode == "cpu-ps" or ps_mode == 'heter-ps':
os.environ['WITH_DISTRIBUTE'] = 'ON'
cmd = [
......@@ -45,7 +46,16 @@ class PsPassTestBase(unittest.TestCase):
"-u",
] + [
"-m", "launch", "--log_dir", config['log_dir'], "--worker_num",
config['worker_num'], "--server_num", config['server_num'],
config['worker_num'], "--server_num", config['server_num']
]
if ps_mode == 'heter-ps':
os.environ['FLAGS_START_PORT'] = '12004'
cmd += [
'--heter_worker_num', config['heter_worker_num'],
'--heter_devices', config['heter_devices']
]
cmd += [
"../ps/ps_dnn_trainer.py", "-m", config['ps_mode_config'],
"--run_minimize", config['run_minimize'], "--run_single_pass",
config['run_single_pass'], "--debug_new_pass",
......
......@@ -63,6 +63,27 @@ class TestPsTrainerPass(PsPassTestBase):
self.check()
# heter ps 三阶段待测
def test_ps_optimizer_minimize_heter(self):
self.init()
self.config['worker_num'] = "2"
self.config['server_num'] = "2"
self.config['heter_worker_num'] = '2'
self.config['heter_devices'] = 'gpu'
self.config['run_minimize'] = '1'
self.config['ps_mode_config'] = "../ps/heter_ps_config.yaml"
self.config['debug_new_minimize'] = '0'
self.config['log_dir'] = "/heter_log_old_minimize"
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config, 'heter-ps')
self.config['debug_new_minimize'] = '1'
self.config['log_dir'] = "/heter_log_new_minimize"
remove_path_if_exists(self.config['log_dir'])
self.ps_launch(self.config, 'heter-ps')
def test_ps_optimizer_minimize_gpu(self):
self.init()
self.config['run_minimize'] = '1'
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
hyper_parameters:
optimizer:
class: Adam
learning_rate: 0.0001
strategy: async # 有用
sparse_inputs_slots: 27
sparse_feature_number: 1024
sparse_feature_dim: 11
dense_input_dim: 13
fc_sizes: [512, 256, 128, 32]
distributed_embedding: 0
runner:
sync_mode: "heter"
thread_num: 8
micro_num: 8 # micro batch num for each thread
pipeline: True
model_path: "../ps_dnn_model.py"
......@@ -23,7 +23,6 @@ import yaml, six, copy
import paddle
import os
import warnings
import logging
import ast
import numpy as np
import struct
......@@ -176,6 +175,10 @@ def get_user_defined_strategy(config):
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
strategy.a_sync_configs = {"heter_worker_device_guard": "gpu"}
strategy.pipeline = True
strategy.pipeline_configs = {
"accumulate_steps": config.get('runner.micro_num')
}
elif sync_mode == "gpubox":
print("sync_mode = {}".format(sync_mode))
strategy = paddle.distributed.fleet.DistributedStrategy()
......@@ -328,6 +331,7 @@ class DnnTrainer(object):
if self.config['debug_new_minimize'] == 1:
logger.info("entering run_minimize -- new")
self.role_maker._generate_role() # 必要
from paddle.distributed.fleet.meta_optimizers.ps_optimizer import ParameterServerOptimizer
ps_optimizer = ParameterServerOptimizer(inner_optimizer)
ps_optimizer._set_basic_info(loss, self.role_maker, inner_optimizer,
......
......@@ -17,6 +17,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
import math
import paddle.distributed.fleet as fleet
from paddle.distributed.ps.utils.public import logger
class DNNLayer(nn.Layer):
......@@ -77,6 +78,11 @@ class DNNLayer(nn.Layer):
y_dnn = paddle.concat(x=sparse_embs + [dense_inputs], axis=1)
if self.sync_mode == 'heter':
with paddle.fluid.device_guard('gpu'):
for n_layer in self._mlp_layers:
y_dnn = n_layer(y_dnn)
else:
for n_layer in self._mlp_layers:
y_dnn = n_layer(y_dnn)
......
#Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
......@@ -24,38 +23,39 @@ from op_test import OpTest
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.op import Operator
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _test_eager_guard
class TestWhereOp(OpTest):
def setUp(self):
self.op_type = "where"
self.op_type = 'where'
self.init_config()
self.inputs = {'Condition': self.cond, 'X': self.x, 'Y': self.y}
self.outputs = {'Out': np.where(self.cond, self.x, self.y)}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X', 'Y'], 'Out')
self.check_grad(['X', 'Y'], 'Out', check_eager=True)
def init_config(self):
self.x = np.random.uniform(-3, 5, (100)).astype("float64")
self.y = np.random.uniform(-3, 5, (100)).astype("float64")
self.cond = np.zeros((100)).astype("bool")
self.x = np.random.uniform((-3), 5, 100).astype('float64')
self.y = np.random.uniform((-3), 5, 100).astype('float64')
self.cond = np.zeros(100).astype('bool')
class TestWhereOp2(TestWhereOp):
def init_config(self):
self.x = np.random.uniform(-5, 5, (60, 2)).astype("float64")
self.y = np.random.uniform(-5, 5, (60, 2)).astype("float64")
self.cond = np.ones((60, 2)).astype("bool")
self.x = np.random.uniform((-5), 5, (60, 2)).astype('float64')
self.y = np.random.uniform((-5), 5, (60, 2)).astype('float64')
self.cond = np.ones((60, 2)).astype('bool')
class TestWhereOp3(TestWhereOp):
def init_config(self):
self.x = np.random.uniform(-3, 5, (20, 2, 4)).astype("float64")
self.y = np.random.uniform(-3, 5, (20, 2, 4)).astype("float64")
self.x = np.random.uniform((-3), 5, (20, 2, 4)).astype('float64')
self.y = np.random.uniform((-3), 5, (20, 2, 4)).astype('float64')
self.cond = np.array(np.random.randint(2, size=(20, 2, 4)), dtype=bool)
......@@ -66,15 +66,15 @@ class TestWhereAPI(unittest.TestCase):
def init_data(self):
self.shape = [10, 15]
self.cond = np.array(np.random.randint(2, size=self.shape), dtype=bool)
self.x = np.random.uniform(-2, 3, self.shape).astype(np.float32)
self.y = np.random.uniform(-2, 3, self.shape).astype(np.float32)
self.x = np.random.uniform((-2), 3, self.shape).astype(np.float32)
self.y = np.random.uniform((-2), 3, self.shape).astype(np.float32)
self.out = np.where(self.cond, self.x, self.y)
def ref_x_backward(self, dout):
return np.where(self.cond == True, dout, 0)
return np.where((self.cond == True), dout, 0)
def ref_y_backward(self, dout):
return np.where(self.cond == False, dout, 0)
return np.where((self.cond == False), dout, 0)
def test_api(self, use_cuda=False):
for x_stop_gradient in [False, True]:
......@@ -90,17 +90,17 @@ class TestWhereAPI(unittest.TestCase):
y.stop_gradient = y_stop_gradient
result = paddle.where(cond, x, y)
append_backward(layers.mean(result))
for use_cuda in [False, True]:
if use_cuda and not fluid.core.is_compiled_with_cuda():
if (use_cuda and
(not fluid.core.is_compiled_with_cuda())):
break
place = fluid.CUDAPlace(
0) if use_cuda else fluid.CPUPlace()
place = (fluid.CUDAPlace(0)
if use_cuda else fluid.CPUPlace())
exe = fluid.Executor(place)
fetch_list = [result, result.grad_name]
if x_stop_gradient is False:
if (x_stop_gradient is False):
fetch_list.append(x.grad_name)
if y_stop_gradient is False:
if (y_stop_gradient is False):
fetch_list.append(y.grad_name)
out = exe.run(
fluid.default_main_program(),
......@@ -109,13 +109,13 @@ class TestWhereAPI(unittest.TestCase):
'y': self.y},
fetch_list=fetch_list)
assert np.array_equal(out[0], self.out)
if x_stop_gradient is False:
if (x_stop_gradient is False):
assert np.array_equal(out[2],
self.ref_x_backward(out[1]))
if y.stop_gradient is False:
if (y.stop_gradient is False):
assert np.array_equal(
out[3], self.ref_y_backward(out[1]))
elif y.stop_gradient is False:
elif (y.stop_gradient is False):
assert np.array_equal(out[2],
self.ref_y_backward(out[1]))
......@@ -124,44 +124,38 @@ class TestWhereAPI(unittest.TestCase):
with fluid.program_guard(main_program):
x = fluid.layers.data(name='x', shape=[4, 1], dtype='float32')
y = fluid.layers.data(name='y', shape=[4, 2], dtype='float32')
x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32")
y_i = np.array([[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0]]).astype("float32")
result = paddle.where(x > 1, x=x, y=y)
x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype('float32')
y_i = np.array(
[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype('float32')
result = paddle.where((x > 1), x=x, y=y)
for use_cuda in [False, True]:
if use_cuda and not fluid.core.is_compiled_with_cuda():
if (use_cuda and (not fluid.core.is_compiled_with_cuda())):
return
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
place = (fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace())
exe = fluid.Executor(place)
out = exe.run(fluid.default_main_program(),
feed={'x': x_i,
'y': y_i},
fetch_list=[result])
assert np.array_equal(out[0], np.where(x_i > 1, x_i, y_i))
assert np.array_equal(out[0], np.where((x_i > 1), x_i, y_i))
def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape):
paddle.enable_static()
main_program = Program()
with fluid.program_guard(main_program):
cond = fluid.layers.data(
name='cond', shape=cond_shape, dtype='bool')
x = fluid.layers.data(name='x', shape=x_shape, dtype='float32')
y = fluid.layers.data(name='y', shape=y_shape, dtype='float32')
cond_data_tmp = np.random.random(size=cond_shape).astype("float32")
cond_data = cond_data_tmp < 0.3
x_data = np.random.random(size=x_shape).astype("float32")
y_data = np.random.random(size=y_shape).astype("float32")
cond_data_tmp = np.random.random(size=cond_shape).astype('float32')
cond_data = (cond_data_tmp < 0.3)
x_data = np.random.random(size=x_shape).astype('float32')
y_data = np.random.random(size=y_shape).astype('float32')
result = paddle.where(condition=cond, x=x, y=y)
for use_cuda in [False, True]:
if use_cuda and not fluid.core.is_compiled_with_cuda():
if (use_cuda and (not fluid.core.is_compiled_with_cuda())):
return
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
place = (fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace())
exe = fluid.Executor(place)
out = exe.run(
fluid.default_main_program(),
......@@ -169,9 +163,7 @@ class TestWhereAPI(unittest.TestCase):
'x': x_data,
'y': y_data},
fetch_list=[result])
expect = np.where(cond_data, x_data, y_data)
assert np.array_equal(out[0], expect)
def test_static_api_broadcast_1(self):
......@@ -198,28 +190,24 @@ class TestWhereAPI(unittest.TestCase):
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_static_api_broadcast_5(self):
cond_shape = [3, 2, 2, 4]
a_shape = [2, 2, 4]
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_static_api_broadcast_6(self):
cond_shape = [2, 2, 4]
a_shape = [2, 2, 1]
b_shape = [2, 2, 1]
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_static_api_broadcast_7(self):
cond_shape = [2, 2, 4]
a_shape = [2, 1, 4]
b_shape = [2, 1, 4]
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_static_api_broadcast_8(self):
cond_shape = [3, 2, 2, 4]
a_shape = [2, 2, 1]
......@@ -230,9 +218,9 @@ class TestWhereAPI(unittest.TestCase):
class TestWhereDygraphAPI(unittest.TestCase):
def test_api(self):
with fluid.dygraph.guard():
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64")
cond_i = np.array([False, False, True, True]).astype("bool")
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype('float64')
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype('float64')
cond_i = np.array([False, False, True, True]).astype('bool')
x = fluid.dygraph.to_variable(x_i)
y = fluid.dygraph.to_variable(y_i)
cond = fluid.dygraph.to_variable(cond_i)
......@@ -242,15 +230,12 @@ class TestWhereDygraphAPI(unittest.TestCase):
def __test_where_with_broadcast_dygraph(self, cond_shape, a_shape, b_shape):
with fluid.dygraph.guard():
cond_tmp = paddle.rand(cond_shape)
cond = cond_tmp < 0.3
cond = (cond_tmp < 0.3)
a = paddle.rand(a_shape)
b = paddle.rand(b_shape)
result = paddle.where(cond, a, b)
result = result.numpy()
expect = np.where(cond, a, b)
self.assertTrue(np.array_equal(expect, result))
def test_dygraph_api_broadcast_1(self):
......@@ -277,28 +262,24 @@ class TestWhereDygraphAPI(unittest.TestCase):
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_dygraph_api_broadcast_5(self):
cond_shape = [3, 2, 2, 4]
a_shape = [2, 2, 4]
b_shape = [2, 2, 4]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_dygraph_api_broadcast_6(self):
cond_shape = [2, 2, 4]
a_shape = [2, 2, 1]
b_shape = [2, 2, 1]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_dygraph_api_broadcast_7(self):
cond_shape = [2, 2, 4]
a_shape = [2, 1, 4]
b_shape = [2, 1, 4]
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
# @Note Now, maybe not compatibility with old version
def test_dygraph_api_broadcast_8(self):
cond_shape = [3, 2, 2, 4]
a_shape = [2, 2, 1]
......@@ -308,40 +289,50 @@ class TestWhereDygraphAPI(unittest.TestCase):
def test_where_condition(self):
data = np.array([[True, False], [False, True]])
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 2])
x = fluid.layers.data(name='x', shape=[(-1), 2])
y = paddle.where(x)
self.assertEqual(type(y), tuple)
self.assertEqual(len(y), 2)
z = fluid.layers.concat(list(y), axis=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': data},
(res, ) = exe.run(feed={'x': data},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[0, 0], [1, 1]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
data = np.array([True, True, False])
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1])
x = fluid.layers.data(name='x', shape=[(-1)])
y = paddle.where(x)
self.assertEqual(type(y), tuple)
self.assertEqual(len(y), 1)
z = fluid.layers.concat(list(y), axis=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': data},
(res, ) = exe.run(feed={'x': data},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[0], [1]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
def test_eager(self):
with _test_eager_guard():
self.test_api()
self.test_dygraph_api_broadcast_1()
self.test_dygraph_api_broadcast_2()
self.test_dygraph_api_broadcast_3()
self.test_dygraph_api_broadcast_4()
self.test_dygraph_api_broadcast_5()
self.test_dygraph_api_broadcast_6()
self.test_dygraph_api_broadcast_7()
self.test_dygraph_api_broadcast_8()
class TestWhereOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64")
cond_i = np.array([False, False, True, True]).astype("bool")
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype('float64')
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype('float64')
cond_i = np.array([False, False, True, True]).astype('bool')
def test_Variable():
paddle.where(cond_i, x_i, y_i)
......@@ -360,10 +351,14 @@ class TestWhereOpError(unittest.TestCase):
with fluid.dygraph.guard():
cond_shape = [2, 2, 4]
cond_tmp = paddle.rand(cond_shape)
cond = cond_tmp < 0.3
cond = (cond_tmp < 0.3)
a = paddle.rand(cond_shape)
self.assertRaises(ValueError, paddle.where, cond, a)
def test_eager(self):
with _test_eager_guard():
self.test_value_error()
if __name__ == '__main__':
if (__name__ == '__main__'):
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,23 +13,22 @@
# limitations under the License.
from __future__ import division
import unittest
import numpy as np
from op_test import OpTest
import paddle
from paddle.fluid import core
from paddle.fluid.framework import _test_eager_guard
def sigmoid(x):
return 1.0 / (1.0 + np.exp(-1.0 * x))
return (1.0 / (1.0 + np.exp(((-1.0) * x))))
def YoloBox(x, img_size, attrs):
n, c, h, w = x.shape
(n, c, h, w) = x.shape
anchors = attrs['anchors']
an_num = int(len(anchors) // 2)
an_num = int((len(anchors) // 2))
class_num = attrs['class_num']
conf_thresh = attrs['conf_thresh']
downsample = attrs['downsample']
......@@ -37,60 +36,56 @@ def YoloBox(x, img_size, attrs):
scale_x_y = attrs['scale_x_y']
iou_aware = attrs['iou_aware']
iou_aware_factor = attrs['iou_aware_factor']
bias_x_y = -0.5 * (scale_x_y - 1.)
input_h = downsample * h
input_w = downsample * w
bias_x_y = ((-0.5) * (scale_x_y - 1.0))
input_h = (downsample * h)
input_w = (downsample * w)
if iou_aware:
ioup = x[:, :an_num, :, :]
ioup = np.expand_dims(ioup, axis=-1)
ioup = np.expand_dims(ioup, axis=(-1))
x = x[:, an_num:, :, :]
x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
x = x.reshape((n, an_num, (5 + class_num), h, w)).transpose((0, 1, 3, 4, 2))
pred_box = x[:, :, :, :, :4].copy()
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
pred_box[:, :, :, :, 0] = (
grid_x + sigmoid(pred_box[:, :, :, :, 0]) * scale_x_y + bias_x_y) / w
pred_box[:, :, :, :, 1] = (
grid_y + sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y + bias_x_y) / h
anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
pred_box[:, :, :, :, 0] = ((
(grid_x + (sigmoid(pred_box[:, :, :, :, 0]) * scale_x_y)) + bias_x_y) /
w)
pred_box[:, :, :, :, 1] = ((
(grid_y + (sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y)) + bias_x_y) /
h)
anchors = [(anchors[i], anchors[(i + 1)])
for i in range(0, len(anchors), 2)]
anchors_s = np.array(
[(an_w / input_w, an_h / input_h) for an_w, an_h in anchors])
[((an_w / input_w), (an_h / input_h)) for (an_w, an_h) in anchors])
anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1))
anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1))
pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w
pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h
pred_box[:, :, :, :, 2] = (np.exp(pred_box[:, :, :, :, 2]) * anchor_w)
pred_box[:, :, :, :, 3] = (np.exp(pred_box[:, :, :, :, 3]) * anchor_h)
if iou_aware:
pred_conf = sigmoid(x[:, :, :, :, 4:5])**(
1 - iou_aware_factor) * sigmoid(ioup)**iou_aware_factor
pred_conf = ((sigmoid(x[:, :, :, :, 4:5])**(1 - iou_aware_factor)) *
(sigmoid(ioup)**iou_aware_factor))
else:
pred_conf = sigmoid(x[:, :, :, :, 4:5])
pred_conf[pred_conf < conf_thresh] = 0.
pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf
pred_box = pred_box * (pred_conf > 0.).astype('float32')
pred_box = pred_box.reshape((n, -1, 4))
pred_box[:, :, :2], pred_box[:, :, 2:4] = \
pred_box[:, :, :2] - pred_box[:, :, 2:4] / 2., \
pred_box[:, :, :2] + pred_box[:, :, 2:4] / 2.0
pred_box[:, :, 0] = pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis]
pred_box[:, :, 1] = pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis]
pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis]
pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis]
pred_conf[(pred_conf < conf_thresh)] = 0.0
pred_score = (sigmoid(x[:, :, :, :, 5:]) * pred_conf)
pred_box = (pred_box * (pred_conf > 0.0).astype('float32'))
pred_box = pred_box.reshape((n, (-1), 4))
(pred_box[:, :, :2], pred_box[:, :, 2:4]) = (
(pred_box[:, :, :2] - (pred_box[:, :, 2:4] / 2.0)),
(pred_box[:, :, :2] + (pred_box[:, :, 2:4] / 2.0)))
pred_box[:, :, 0] = (pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis])
pred_box[:, :, 1] = (pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis])
pred_box[:, :, 2] = (pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis])
pred_box[:, :, 3] = (pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis])
if clip_bbox:
for i in range(len(pred_box)):
pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf)
pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf)
pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], -np.inf,
img_size[i, 1] - 1)
pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], -np.inf,
img_size[i, 0] - 1)
return pred_box, pred_score.reshape((n, -1, class_num))
pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], (-np.inf),
(img_size[(i, 1)] - 1))
pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], (-np.inf),
(img_size[(i, 0)] - 1))
return (pred_box, pred_score.reshape((n, (-1), class_num)))
class TestYoloBoxOp(OpTest):
......@@ -99,42 +94,35 @@ class TestYoloBoxOp(OpTest):
self.op_type = 'yolo_box'
x = np.random.random(self.x_shape).astype('float32')
img_size = np.random.randint(10, 20, self.imgsize_shape).astype('int32')
self.attrs = {
"anchors": self.anchors,
"class_num": self.class_num,
"conf_thresh": self.conf_thresh,
"downsample": self.downsample,
"clip_bbox": self.clip_bbox,
"scale_x_y": self.scale_x_y,
"iou_aware": self.iou_aware,
"iou_aware_factor": self.iou_aware_factor
}
self.inputs = {
'X': x,
'ImgSize': img_size,
}
boxes, scores = YoloBox(x, img_size, self.attrs)
self.outputs = {
"Boxes": boxes,
"Scores": scores,
'anchors': self.anchors,
'class_num': self.class_num,
'conf_thresh': self.conf_thresh,
'downsample': self.downsample,
'clip_bbox': self.clip_bbox,
'scale_x_y': self.scale_x_y,
'iou_aware': self.iou_aware,
'iou_aware_factor': self.iou_aware_factor
}
self.inputs = {'X': x, 'ImgSize': img_size}
(boxes, scores) = YoloBox(x, img_size, self.attrs)
self.outputs = {'Boxes': boxes, 'Scores': scores}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
an_num = int((len(self.anchors) // 2))
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = True
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13,
13)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.
self.scale_x_y = 1.0
self.iou_aware = False
self.iou_aware_factor = 0.5
......@@ -142,15 +130,16 @@ class TestYoloBoxOp(OpTest):
class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
an_num = int((len(self.anchors) // 2))
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = False
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13,
13)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.
self.scale_x_y = 1.0
self.iou_aware = False
self.iou_aware_factor = 0.5
......@@ -158,13 +147,14 @@ class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
class TestYoloBoxOpScaleXY(TestYoloBoxOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
an_num = int((len(self.anchors) // 2))
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = True
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13,
13)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.2
self.iou_aware = False
......@@ -174,15 +164,16 @@ class TestYoloBoxOpScaleXY(TestYoloBoxOp):
class TestYoloBoxOpIoUAware(TestYoloBoxOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
an_num = int((len(self.anchors) // 2))
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = True
self.x_shape = (self.batch_size, an_num * (6 + self.class_num), 13, 13)
self.x_shape = (self.batch_size, (an_num * (6 + self.class_num)), 13,
13)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.
self.scale_x_y = 1.0
self.iou_aware = True
self.iou_aware_factor = 0.5
......@@ -192,10 +183,9 @@ class TestYoloBoxDygraph(unittest.TestCase):
paddle.disable_static()
img_size = np.ones((2, 2)).astype('int32')
img_size = paddle.to_tensor(img_size)
x1 = np.random.random([2, 14, 8, 8]).astype('float32')
x1 = paddle.to_tensor(x1)
boxes, scores = paddle.vision.ops.yolo_box(
(boxes, scores) = paddle.vision.ops.yolo_box(
x1,
img_size=img_size,
anchors=[10, 13, 16, 30],
......@@ -203,12 +193,11 @@ class TestYoloBoxDygraph(unittest.TestCase):
conf_thresh=0.01,
downsample_ratio=8,
clip_bbox=True,
scale_x_y=1.)
assert boxes is not None and scores is not None
scale_x_y=1.0)
assert ((boxes is not None) and (scores is not None))
x2 = np.random.random([2, 16, 8, 8]).astype('float32')
x2 = paddle.to_tensor(x2)
boxes, scores = paddle.vision.ops.yolo_box(
(boxes, scores) = paddle.vision.ops.yolo_box(
x2,
img_size=img_size,
anchors=[10, 13, 16, 30],
......@@ -216,18 +205,21 @@ class TestYoloBoxDygraph(unittest.TestCase):
conf_thresh=0.01,
downsample_ratio=8,
clip_bbox=True,
scale_x_y=1.,
scale_x_y=1.0,
iou_aware=True,
iou_aware_factor=0.5)
paddle.enable_static()
def test_eager(self):
with _test_eager_guard():
self.test_dygraph()
class TestYoloBoxStatic(unittest.TestCase):
def test_static(self):
x1 = paddle.static.data('x1', [2, 14, 8, 8], 'float32')
img_size = paddle.static.data('img_size', [2, 2], 'int32')
boxes, scores = paddle.vision.ops.yolo_box(
(boxes, scores) = paddle.vision.ops.yolo_box(
x1,
img_size=img_size,
anchors=[10, 13, 16, 30],
......@@ -235,11 +227,10 @@ class TestYoloBoxStatic(unittest.TestCase):
conf_thresh=0.01,
downsample_ratio=8,
clip_bbox=True,
scale_x_y=1.)
assert boxes is not None and scores is not None
scale_x_y=1.0)
assert ((boxes is not None) and (scores is not None))
x2 = paddle.static.data('x2', [2, 16, 8, 8], 'float32')
boxes, scores = paddle.vision.ops.yolo_box(
(boxes, scores) = paddle.vision.ops.yolo_box(
x2,
img_size=img_size,
anchors=[10, 13, 16, 30],
......@@ -247,27 +238,27 @@ class TestYoloBoxStatic(unittest.TestCase):
conf_thresh=0.01,
downsample_ratio=8,
clip_bbox=True,
scale_x_y=1.,
scale_x_y=1.0,
iou_aware=True,
iou_aware_factor=0.5)
assert boxes is not None and scores is not None
assert ((boxes is not None) and (scores is not None))
class TestYoloBoxOpHW(TestYoloBoxOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
an_num = int((len(self.anchors) // 2))
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = False
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 9)
self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13, 9)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.
self.scale_x_y = 1.0
self.iou_aware = False
self.iou_aware_factor = 0.5
if __name__ == "__main__":
if (__name__ == '__main__'):
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,13 +13,13 @@
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle import zeros_like
from paddle.fluid import core, Program, program_guard
from paddle.fluid.framework import _test_eager_guard
class TestZerosLikeAPIError(unittest.TestCase):
......@@ -28,6 +28,10 @@ class TestZerosLikeAPIError(unittest.TestCase):
x = paddle.fluid.data('x', [3, 4])
self.assertRaises(TypeError, zeros_like, x, 'int8')
def test_eager(self):
with _test_eager_guard():
self.test_errors()
class TestZerosLikeAPI(unittest.TestCase):
def test_api(self):
......@@ -36,46 +40,48 @@ class TestZerosLikeAPI(unittest.TestCase):
train_program = Program()
with program_guard(train_program, startup_program):
x = paddle.fluid.data('X', shape)
# 'bool', 'float32', 'float64', 'int32', 'int64'
out1 = zeros_like(x)
out2 = zeros_like(x, np.bool)
out3 = zeros_like(x, 'float64')
out4 = zeros_like(x, 'int32')
out5 = zeros_like(x, 'int64')
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
place = (fluid.CUDAPlace(0)
if core.is_compiled_with_cuda() else fluid.CPUPlace())
exe = fluid.Executor(place)
outs = exe.run(train_program,
feed={'X': np.ones(shape).astype('float32')},
fetch_list=[out1, out2, out3, out4, out5])
for i, dtype in enumerate(
for (i, dtype) in enumerate(
[np.float32, np.bool, np.float64, np.int32, np.int64]):
self.assertEqual(outs[i].dtype, dtype)
self.assertEqual((outs[i] == np.zeros(shape, dtype)).all(), True)
def test_eager(self):
with _test_eager_guard():
self.test_api()
class TestZerosLikeImpeartive(unittest.TestCase):
def test_out(self):
shape = [3, 4]
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
place = (fluid.CUDAPlace(0)
if core.is_compiled_with_cuda() else fluid.CPUPlace())
paddle.disable_static(place)
x = paddle.to_tensor(np.ones(shape))
for dtype in [np.bool, np.float32, np.float64, np.int32, np.int64]:
out = zeros_like(x, dtype)
self.assertEqual((out.numpy() == np.zeros(shape, dtype)).all(),
True)
out = paddle.tensor.zeros_like(x)
self.assertEqual((out.numpy() == np.zeros(shape, dtype)).all(), True)
out = paddle.tensor.creation.zeros_like(x)
self.assertEqual((out.numpy() == np.zeros(shape, dtype)).all(), True)
paddle.enable_static()
def test_eager(self):
with _test_eager_guard():
self.test_out()
if __name__ == "__main__":
if (__name__ == '__main__'):
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,56 +13,55 @@
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.framework import _test_eager_guard
class TestZerosOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# The input dtype of zeros_op must be bool, float16, float32, float64, int32, int64.
shape = [4]
dtype = "int8"
dtype = 'int8'
self.assertRaises(TypeError, fluid.layers.zeros, shape, dtype)
def test_eager(self):
with _test_eager_guard():
self.test_errors()
class ApiZerosTest(unittest.TestCase):
def test_out(self):
with program_guard(Program()):
zeros = paddle.zeros(shape=[10], dtype="float64")
zeros = paddle.zeros(shape=[10], dtype='float64')
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
result, = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype="float64")
(result, ) = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype='float64')
self.assertEqual((result == expected_result).all(), True)
with paddle.static.program_guard(Program()):
zeros = paddle.zeros(shape=[10], dtype="int64")
zeros = paddle.zeros(shape=[10], dtype='int64')
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
result, = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype="int64")
(result, ) = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype='int64')
self.assertEqual((result == expected_result).all(), True)
with program_guard(Program()):
zeros = paddle.zeros(shape=[10], dtype="int64")
zeros = paddle.zeros(shape=[10], dtype='int64')
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
result, = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype="int64")
(result, ) = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype='int64')
self.assertEqual((result == expected_result).all(), True)
with program_guard(Program()):
out_np = np.zeros(shape=(1), dtype='float32')
out = paddle.zeros(shape=[1], dtype="float32")
out_np = np.zeros(shape=1, dtype='float32')
out = paddle.zeros(shape=[1], dtype='float32')
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
result = exe.run(fetch_list=[out])
......@@ -70,28 +69,37 @@ class ApiZerosTest(unittest.TestCase):
def test_fluid_out(self):
with program_guard(Program()):
zeros = fluid.layers.zeros(shape=[10], dtype="int64")
zeros = fluid.layers.zeros(shape=[10], dtype='int64')
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
result, = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype="int64")
(result, ) = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype='int64')
self.assertEqual((result == expected_result).all(), True)
def test_eager(self):
with _test_eager_guard():
self.test_out()
self.test_fluid_out()
class ApiZerosError(unittest.TestCase):
def test_errors(self):
def test_error1():
with paddle.static.program_guard(fluid.Program()):
ones = fluid.layers.zeros(shape=10, dtype="int64")
ones = fluid.layers.zeros(shape=10, dtype='int64')
self.assertRaises(TypeError, test_error1)
def test_error2():
with paddle.static.program_guard(fluid.Program()):
ones = fluid.layers.zeros(shape=[10], dtype="int8")
ones = fluid.layers.zeros(shape=[10], dtype='int8')
self.assertRaises(TypeError, test_error2)
def test_eager(self):
with _test_eager_guard():
self.test_errors()
if __name__ == "__main__":
if (__name__ == '__main__'):
unittest.main()
......@@ -17,6 +17,7 @@ from __future__ import print_function
import inspect
import os
import fcntl
import numpy as np
import paddle
import paddle.fluid.core as core
......@@ -29,28 +30,61 @@ type_dict_paddle_to_str = {
paddle.int32: 'int32',
paddle.int64: 'int64',
paddle.float16: 'float16',
paddle.bfloat16: 'bfloat16',
paddle.float32: 'float32',
paddle.float64: 'float64',
paddle.complex128: 'complex128',
paddle.complex64: 'complex64',
}
type_dict_paddle_to_numpy = {
paddle.bool: np.bool_,
paddle.uint8: np.uint8,
paddle.int8: np.int8,
paddle.int16: np.int16,
paddle.int32: np.int32,
paddle.int64: np.int64,
paddle.bfloat16: np.uint16,
paddle.float16: np.float16,
paddle.float32: np.float32,
paddle.float64: np.float64,
paddle.complex128: np.complex128,
paddle.complex64: np.complex64,
}
type_dict_str_to_paddle = {
'uint8': paddle.uint8,
'int8': paddle.int8,
'int16': paddle.int16,
'int32': paddle.int32,
'int64': paddle.int64,
'float32': paddle.float32,
'bfloat16': paddle.bfloat16,
'float16': paddle.float16,
'float32': paddle.float32,
'float64': paddle.float64,
'bool': paddle.bool,
'uint8': paddle.uint8,
'int8': paddle.int8,
'complex128': paddle.complex128,
'complex64': paddle.complex64,
'int16': paddle.int16,
'complex128': paddle.complex128,
}
type_dict_str_to_numpy = {
'uint8': np.uint8,
'int8': np.int8,
'int16': np.int16,
'int32': np.int32,
'int64': np.int64,
'bfloat16': np.uint16,
'float16': np.float16,
'float32': np.float32,
'float64': np.float64,
'bool': np.bool_,
'complex64': np.complex64,
'complex128': np.complex128,
}
xpu_test_op_white_list = []
xpu_test_type_white_list = []
xpu_test_op_type_white_list = []
xpu_test_op_type_white_list = ['float64']
xpu_test_device_op_white_list = []
xpu_test_device_op_type_white_list = []
......@@ -122,6 +156,8 @@ def make_xpu_op_list(xpu_version):
if op_name in op_white_list or device_op_name in device_op_white_list:
continue
for op_type in type_list:
if op_type == paddle.bfloat16:
op_type = paddle.bfloat16
if op_type in type_white_list or op_type not in type_dict_paddle_to_str.keys(
):
continue
......@@ -143,10 +179,17 @@ def get_xpu_op_support_types(op_name, dev_id=0):
xpu_version = core.get_xpu_device_version(dev_id)
support_type_list = core.get_xpu_device_op_support_types(op_name,
xpu_version)
support_type_str_list = [
type_dict_paddle_to_str[x] for x in support_type_list
support_type_str_list = []
for stype in support_type_list:
if stype == paddle.bfloat16:
support_type_str_list.append(type_dict_paddle_to_str[
paddle.bfloat16])
else:
support_type_str_list.append(type_dict_paddle_to_str[stype])
type_white_list = get_op_type_white_list()
return [
stype for stype in support_type_str_list if stype not in type_white_list
]
return support_type_str_list
def record_op_test(op_name, test_type):
......@@ -196,8 +239,9 @@ def create_test_class(func_globals,
continue
class_obj = test_class[1]
cls_name = "{0}_{1}".format(test_class[0], str(test_type))
func_globals[cls_name] = type(cls_name, (class_obj, ),
{'in_type': test_type})
func_globals[cls_name] = type(
cls_name, (class_obj, ),
{'in_type': type_dict_str_to_numpy[test_type]})
if hasattr(test_class_obj, 'use_dynamic_create_class'
) and test_class_obj.use_dynamic_create_class:
......@@ -205,7 +249,7 @@ def create_test_class(func_globals,
for dy_class in dynamic_classes:
cls_name = "{0}_{1}".format(dy_class[0], str(test_type))
attr_dict = dy_class[1]
attr_dict['in_type'] = test_type
attr_dict['in_type'] = type_dict_str_to_numpy[test_type]
func_globals[cls_name] = type(cls_name, (base_class, ), attr_dict)
record_op_test(op_name, test_type)
......
......@@ -24,23 +24,41 @@ from op_test_xpu import OpTest, XPUOpTest
import paddle
from paddle.fluid import Program, program_guard
import op_test
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
class TestClipOp(XPUOpTest):
def set_xpu(self):
self.__class__.use_xpu = True
self.place = paddle.XPUPlace(0)
class XPUTestClipOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'clip'
self.use_dynamic_create_class = False
class TestClipOp(XPUOpTest):
def setUp(self):
self.init_dtype()
self.set_xpu()
self.max_relative_error = 0.006
self.op_type = "clip"
self.place = paddle.XPUPlace(0)
self.inputs = {}
self.initTestCase()
self.init_data()
self.set_attrs()
self.set_inputs()
self.outputs = {
'Out': np.clip(self.inputs['X'], self.min_v, self.max_v)
}
self.op_type = "clip"
self.attrs = {}
self.attrs['min'] = self.min
self.attrs['max'] = self.max
def set_xpu(self):
self.__class__.use_xpu = True
self.__class__.no_need_check_grad = True
self.__class__.op_type = self.dtype
def init_data(self):
self.shape = (4, 10, 10)
self.max = 0.8
self.min = 0.3
def set_inputs(self):
if 'Min' in self.inputs:
min_v = self.inputs['Min']
else:
......@@ -51,62 +69,55 @@ class TestClipOp(XPUOpTest):
else:
max_v = self.attrs['max']
self.min_v = min_v
self.max_v = max_v
self.max_relative_error = 0.006
input = np.random.random(self.shape).astype("float32")
input[np.abs(input - min_v) < self.max_relative_error] = 0.5
input[np.abs(input - max_v) < self.max_relative_error] = 0.5
self.inputs['X'] = input
self.outputs = {'Out': np.clip(self.inputs['X'], min_v, max_v)}
def set_attrs(self):
self.attrs = {}
self.attrs['min'] = self.min
self.attrs['max'] = self.max
def init_dtype(self):
self.dtype = self.in_type
def test_check_output(self):
paddle.enable_static()
self.check_output_with_place(self.place)
paddle.disable_static()
def test_check_grad_normal(self):
paddle.enable_static()
self.check_grad_with_place(self.place, ['X'], 'Out')
paddle.disable_static()
def initTestCase(self):
self.shape = (4, 10, 10)
self.max = 0.8
self.min = 0.3
self.inputs['Max'] = np.array([0.8]).astype('float32')
self.inputs['Min'] = np.array([0.1]).astype('float32')
class TestCase1(TestClipOp):
def initTestCase(self):
class TestClipOp1(TestClipOp):
def init_data(self):
self.shape = (8, 16, 8)
self.max = 0.7
self.min = 0.0
class TestCase2(TestClipOp):
def initTestCase(self):
class TestClipOp2(TestClipOp):
def init_data(self):
self.shape = (8, 16)
self.max = 1.0
self.min = 0.0
class TestCase3(TestClipOp):
def initTestCase(self):
class TestClipOp3(TestClipOp):
def init_data(self):
self.shape = (4, 8, 16)
self.max = 0.7
self.min = 0.2
class TestCase4(TestClipOp):
def initTestCase(self):
class TestClipOp4(TestClipOp):
def init_data(self):
self.shape = (4, 8, 8)
self.max = 0.7
self.min = 0.2
self.inputs['Max'] = np.array([0.8]).astype('float32')
self.inputs['Min'] = np.array([0.3]).astype('float32')
class TestCase5(TestClipOp):
def initTestCase(self):
class TestClipOp5(TestClipOp):
def init_data(self):
self.shape = (4, 8, 16)
self.max = 0.5
self.min = 0.5
......@@ -212,5 +223,9 @@ class TestInplaceClipAPI(TestClipAPI):
return x.clip_(min, max)
support_types = get_xpu_op_support_types('clip')
for stype in support_types:
create_test_class(globals(), XPUTestClipOp, stype)
if __name__ == '__main__':
unittest.main()
......@@ -18,251 +18,140 @@ import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
from op_test_xpu import XPUOpTest
import paddle.fluid as fluid
import paddle
def gather_nd_grad(x, index):
dout_shape = index.shape[:-1] + x.shape[index.shape[-1]:]
numel = 1
for i in dout_shape:
numel = numel * i
dout = np.full(dout_shape, 1. / numel)
dx = np.full_like(x, 0)
index = tuple(index.reshape(-1, index.shape[-1]).T)
np.add.at(dx, index, dout)
return dx
def test_class1(op_type, typename):
class TestGatherNdOpWithEmptyIndex(XPUOpTest):
#Index has empty element, which means copy entire tensor
def setUp(self):
self.set_xpu()
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
xnp = np.random.random((5, 20)).astype(typename)
self.inputs = {
'X': xnp,
'Index': np.array([[], []]).astype("int32")
}
self.outputs = {
'Out': np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :]))
}
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
cls_name = "{0}_{1}_1".format(op_type, typename)
TestGatherNdOpWithEmptyIndex.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpWithEmptyIndex
def test_class2(op_type, typename):
class TestGatherNdOpWithIndex1(OpTest):
def setUp(self):
self.set_xpu()
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
xnp = np.random.random((5, 20)).astype(typename)
self.inputs = {'X': xnp, 'Index': np.array([1]).astype("int32")}
self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
cls_name = "{0}_{1}_2".format(op_type, typename)
TestGatherNdOpWithIndex1.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpWithIndex1
def test_class3(op_type, typename):
class TestGatherNdOpWithLowIndex(OpTest):
#Index has low rank, X has high rank
def setUp(self):
self.set_xpu()
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
xnp = np.random.uniform(0, 100, (10, 10)).astype(typename)
index = np.array([[1], [2]]).astype("int64")
self.inputs = {'X': xnp, 'Index': index}
self.outputs = {'Out': xnp[tuple(index.T)]}
self.x_grad = gather_nd_grad(xnp, index)
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
import paddle
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
cls_name = "{0}_{1}_3".format(op_type, typename)
TestGatherNdOpWithLowIndex.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpWithLowIndex
paddle.enable_static()
def test_class4(op_type, typename):
class TestGatherNdOpIndex1(OpTest):
#Index has low rank, X has high rank
class XPUTestGatherNd(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'gather_nd'
class XPUTestGatherNdBase(XPUOpTest):
def setUp(self):
self.set_xpu()
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
xnp = np.random.uniform(0, 100, (10, 10)).astype(typename)
index = np.array([1, 2]).astype("int64")
self.inputs = {'X': xnp, 'Index': index}
self.outputs = {'Out': xnp[tuple(index.T)]}
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
cls_name = "{0}_{1}_4".format(op_type, typename)
TestGatherNdOpIndex1.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpIndex1
def test_class5(op_type, typename):
class TestGatherNdOpWithSameIndexAsX(OpTest):
#Index has same rank as X's rank
def setUp(self):
self.set_xpu()
self.dtype = self.in_type
self.__class__.no_need_check_grad = True
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
xnp = np.random.uniform(0, 100, (10, 10)).astype(typename)
index = np.array([[1, 1], [2, 1]]).astype("int64")
self.inputs = {'X': xnp, 'Index': index}
self.outputs = {'Out': xnp[tuple(index.T)]} #[25, 22]
self.init_data()
def set_xpu(self):
self.__class__.use_xpu = True
self.inputs = {'X': self.xnp, 'Index': self.inp}
self.outputs = {'Out': self.output, }
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
cls_name = "{0}_{1}_5".format(op_type, typename)
TestGatherNdOpWithSameIndexAsX.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpWithSameIndexAsX
def test_class6(op_type, typename):
class TestGatherNdOpWithHighRankSame(OpTest):
#Both Index and X have high rank, and Rank(Index) = Rank(X)
def setUp(self):
self.set_xpu()
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
def init_data(self):
self.xnp = np.random.random((5, 20)).astype(self.in_type)
self.inp = np.array([[], []]).astype("int32")
self.output = np.vstack(
(self.xnp[np.newaxis, :], self.xnp[np.newaxis, :]))
class XPUTestGatherNdOpWithEmptyIndex1(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.random((5, 20)).astype(self.in_type)
self.inp = np.array([[], []]).astype("int32")
self.output = np.vstack(
(self.xnp[np.newaxis, :], self.xnp[np.newaxis, :]))
class XPUTestGatherNdOpWithEmptyIndex2(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.random((5, 20)).astype(self.in_type)
self.inp = np.array([[], []]).astype("int64")
self.output = np.vstack(
(self.xnp[np.newaxis, :], self.xnp[np.newaxis, :]))
class XPUTestGatherNdOpWithIndex1(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.random((5, 20)).astype(self.in_type)
self.inp = np.array([1]).astype("int32")
self.output = self.xnp[self.inp]
class XPUTestGatherNdOpWithIndex2(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.random((5, 20)).astype(self.in_type)
self.inp = np.array([1]).astype("int64")
self.output = self.xnp[self.inp]
class XPUTestGatherNdOpWithLowIndex1(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([[1], [2]]).astype("int32")
self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpWithLowIndex2(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([1, 2]).astype("int64")
self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpWithHighRankSame1(XPUTestGatherNdBase):
def init_data(self):
shape = (5, 2, 3, 1, 10)
xnp = np.random.rand(*shape).astype(typename)
index = np.vstack([np.random.randint(
0, s, size=2) for s in shape]).T
self.inputs = {'X': xnp, 'Index': index.astype("int32")}
self.outputs = {'Out': xnp[tuple(index.T)]}
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
cls_name = "{0}_{1}_6".format(op_type, typename)
TestGatherNdOpWithHighRankSame.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpWithHighRankSame
self.xnp = np.random.rand(*shape).astype(self.in_type)
self.inp = np.vstack(
[np.random.randint(
0, s, size=2) for s in shape]).T.astype("int32")
self.output = self.xnp[tuple(self.inp.T)]
def test_class7(op_type, typename):
class TestGatherNdOpWithHighRankDiff(OpTest):
#Both Index and X have high rank, Rank(Index) < Rank(X)
class XPUTestGatherNdOpWithHighRankSame2(XPUTestGatherNdBase):
def init_data(self):
shape = (5, 2, 3, 1, 10)
self.xnp = np.random.rand(*shape).astype(self.in_type)
self.inp = np.vstack(
[np.random.randint(
0, s, size=2) for s in shape]).T.astype("int64")
self.output = self.xnp[tuple(self.inp.T)]
def setUp(self):
self.set_xpu()
self.place = paddle.XPUPlace(0)
self.op_type = "gather_nd"
class XPUTestGatherNdOpWithHighRankDiff1(XPUTestGatherNdBase):
def init_data(self):
shape = (2, 3, 4, 1, 10)
xnp = np.random.rand(*shape).astype(typename)
index = np.vstack(
self.xnp = np.random.rand(*shape).astype(self.in_type)
self.inp = np.vstack(
[np.random.randint(
0, s, size=200) for s in shape]).T
index_re = index.reshape([20, 5, 2, 5])
self.inputs = {'X': xnp, 'Index': index_re.astype("int32")}
self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])}
0, s, size=200) for s in shape]).T.astype("int32")
self.output = self.xnp[tuple(self.inp.T)]
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
cls_name = "{0}_{1}_7".format(op_type, typename)
TestGatherNdOpWithHighRankDiff.__name__ = cls_name
globals()[cls_name] = TestGatherNdOpWithHighRankDiff
class TestGatherNdAPI(unittest.TestCase):
def test_imperative(self):
paddle.disable_static()
input_1 = np.array([[1, 2], [3, 4], [5, 6]])
index_1 = np.array([[1]])
input = fluid.dygraph.to_variable(input_1)
index = fluid.dygraph.to_variable(index_1)
output = paddle.fluid.layers.gather(input, index)
output_np = output.numpy()
expected_output = np.array([3, 4])
self.assertTrue(np.allclose(output_np, expected_output))
paddle.enable_static()
for _typename in {'float32', 'int', 'int64'}:
test_class1('gather_nd', _typename)
test_class2('gather_nd', _typename)
test_class3('gather_nd', _typename)
test_class4('gather_nd', _typename)
test_class5('gather_nd', _typename)
test_class6('gather_nd', _typename)
test_class7('gather_nd', _typename)
class XPUTestGatherNdOpWithHighRankDiff2(XPUTestGatherNdBase):
def init_data(self):
shape = (2, 3, 4, 1, 10)
self.xnp = np.random.rand(*shape).astype(self.in_type)
self.inp = np.vstack(
[np.random.randint(
0, s, size=200) for s in shape]).T.astype("int64")
self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpWithSameIndexAsX1(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([[1, 1], [2, 1]]).astype("int32")
self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpWithSameIndexAsX2(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([[1, 1], [2, 1]]).astype("int64")
self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpIndex1(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([1, 2]).astype("int32")
self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpIndex2(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([1, 2]).astype("int64")
self.output = self.xnp[tuple(self.inp.T)]
support_types = get_xpu_op_support_types('gather_nd')
for stype in support_types:
create_test_class(globals(), XPUTestGatherNd, stype)
if __name__ == "__main__":
unittest.main()
......@@ -69,7 +69,7 @@ class XPUTestArgsortOp1(XPUOpTestWrapper):
self.descending = False if not hasattr(
self, 'init_descending') else self.init_descending
if self.in_type == 'float32':
if self.in_type == np.float32:
self.x = np.random.random(self.input_shape).astype(self.dtype)
else:
self.x = np.random.randint(
......@@ -118,7 +118,7 @@ class XPUTestArgsortOp2(XPUOpTestWrapper):
self.init_axis()
self.init_direction()
if self.in_type == 'float32':
if self.in_type == np.float32:
self.x = np.random.random(self.input_shape).astype(self.dtype)
else:
self.x = np.random.randint(
......
- backward_api : matmul_grad
forward : matmul (const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y) -> Tensor(out)
forward : matmul (const Tensor& x, const Tensor& y, bool transpose_x=false, bool transpose_y=false) -> Tensor(out)
args : (const Tensor& x, const Tensor& y, const Tensor& out_grad, bool transpose_x=false, bool transpose_y=false)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册