Unverified commit a3b3ec68, authored by jianghaicheng, committed by GitHub

add ipu_backend (#36322)

Parent a6d2fddb
@@ -46,6 +46,7 @@ option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
option(WITH_WIN_DUMP_DBG "Compile with windows core dump debug mode" OFF)
option(WITH_ASCEND "Compile PaddlePaddle with ASCEND" OFF)
option(WITH_ROCM "Compile PaddlePaddle with ROCM platform" OFF)
option(WITH_IPU "Compile PaddlePaddle with Graphcore IPU" OFF)
# NOTE(zhiqiu): WITH_ASCEND_CL can be compile on x86_64, so we can set WITH_ASCEND=OFF and WITH_ASCEND_CL=ON
# to develop some acl related functionality on x86
option(WITH_ASCEND_CL "Compile PaddlePaddle with ASCEND CL" ${WITH_ASCEND})
@@ -97,6 +97,11 @@ if(WITH_XPU)
add_definitions(-DPADDLE_WITH_XPU)
endif()
if(WITH_IPU)
message(STATUS "Compile with IPU!")
add_definitions(-DPADDLE_WITH_IPU)
endif()
if(WITH_GPU)
add_definitions(-DPADDLE_WITH_CUDA)
add_definitions(-DEIGEN_USE_GPU)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if(WITH_IPU)
set(POPLAR_DIR CACHE PATH "Path to a Poplar install")
set(POPART_DIR CACHE PATH "Path to a Popart install")
set(POPLAR_SDK_DIR CACHE PATH "Path to an extracted SDK archive or to a Poplar & Popart install directory (Will populate POPLAR_DIR and POPART_DIR)")
if(DEFINED ENV{POPLAR_SDK_DIR})
set(POPLAR_SDK_DIR $ENV{POPLAR_SDK_DIR})
execute_process(COMMAND find ${POPLAR_SDK_DIR}/ -maxdepth 1 -type d -name "popart*"
OUTPUT_VARIABLE POPART_DIR OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND find ${POPLAR_SDK_DIR}/ -maxdepth 1 -type d -name "poplar-*" -o -name "poplar"
OUTPUT_VARIABLE POPLAR_DIR OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT IS_DIRECTORY "${POPLAR_DIR}")
message(FATAL_ERROR "Couldn't find a \"poplar\" or \"poplar-*\" folder in '${POPLAR_SDK_DIR}'")
endif()
if(NOT IS_DIRECTORY "${POPART_DIR}")
message(FATAL_ERROR "Couldn't find a \"popart*\" folder in '${POPLAR_SDK_DIR}'")
endif()
else()
message(FATAL_ERROR "You must provide a path to a Poplar install using export POPLAR_SDK_DIR=/path/to/poplar_sdk")
endif()
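  # Typical usage (hypothetical path): export POPLAR_SDK_DIR=/opt/poplar_sdk
  # before invoking cmake with -DWITH_IPU=ON; POPLAR_DIR and POPART_DIR are
  # then discovered inside the extracted SDK as shown above.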
message("POPLAR_DIR is ${POPLAR_DIR}")
message("POPART_DIR is ${POPART_DIR}")
if(EXISTS ${POPLAR_DIR})
list(APPEND CMAKE_PREFIX_PATH ${POPLAR_DIR})
set(ENABLE_POPLAR_CMD "source ${POPLAR_DIR}/enable.sh")
find_package(poplar REQUIRED)
include_directories("${POPLAR_DIR}/include")
link_directories("${POPLAR_DIR}/lib")
endif()
if(NOT poplar_FOUND)
message(FATAL_ERROR "You must provide a path to a Poplar install using -DPOPLAR_DIR=/path/to/popart/build/install")
endif()
if(EXISTS ${POPART_DIR})
list(APPEND CMAKE_PREFIX_PATH ${POPART_DIR})
set(ENABLE_POPART_CMD "source ${POPART_DIR}/enable.sh")
find_package(popart REQUIRED COMPONENTS popart-only)
include_directories("${POPART_DIR}/include")
link_directories("${POPART_DIR}/lib")
endif()
if(NOT popart_FOUND)
message(FATAL_ERROR "You must provide a path to a Popart build using -DPOPART_DIR=/path/to/popart/build")
endif()
add_definitions(-DONNX_NAMESPACE=onnx)
add_custom_target(extern_poplar DEPENDS poplar popart-only)
endif()
@@ -204,6 +204,9 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
elseif(WITH_ASCEND_CL AND NOT WITH_ASCEND_CXX11)
SET(PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git)
SET(PROTOBUF_TAG v3.8.0)
elseif(WITH_IPU)
SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git)
SET(PROTOBUF_TAG d750fbf648256c7c631f51ffdbf67d7c18b0114e)
else()
SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git)
SET(PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546)
@@ -243,6 +246,8 @@ ENDFUNCTION()
if(WITH_ASCEND OR WITH_ASCEND_CL)
SET(PROTOBUF_VERSION 3.8.0)
elseif(WITH_IPU)
SET(PROTOBUF_VERSION 3.6.1)
else()
SET(PROTOBUF_VERSION 3.1.0)
endif()
@@ -151,6 +151,13 @@ set(COMMON_FLAGS
${fsanitize}
)
if(WITH_IPU)
set(COMMON_FLAGS ${COMMON_FLAGS}
-Wno-sign-compare # Warnings in Popart
-Wno-non-virtual-dtor # Warnings in Popart
)
endif()
if(NOT APPLE)
if((${CMAKE_CXX_COMPILER_VERSION} VERSION_GREATER 8.0) OR (WITH_ROCM))
set(COMMON_FLAGS
@@ -391,4 +391,9 @@ if (WIN32)
list(APPEND third_party_deps extern_dirent)
endif (WIN32)
if (WITH_IPU)
include(external/poplar)
list(APPEND third_party_deps extern_poplar)
endif()
add_custom_target(third_party ALL DEPENDS ${third_party_deps})
@@ -10,3 +10,8 @@ ENDIF()
IF(WITH_ASCEND OR WITH_ASCEND_CL)
add_subdirectory(npu)
ENDIF()
# IPU
IF(WITH_IPU)
add_subdirectory(ipu)
ENDIF()
cc_library(ipu_device SRCS device.cc DEPS enforce popart)
cc_library(ipu_utils SRCS ipu_utils.cc DEPS memory framework_proto popart)
cc_library(ipu_strategy SRCS ipu_strategy.cc DEPS popart graph framework_proto enforce)
cc_library(ipu_optimizer SRCS ipu_optimizer.cc DEPS popart enforce)
cc_library(ipu_executor SRCS ipu_executor.cc DEPS ipu_optimizer ipu_utils popart graph framework_proto)
cc_library(popart_canonicalization_utils SRCS ${POPART_CANONICALIZATION_SRC} DEPS framework_proto enforce ipu_utils)
cc_library(ipu_compiler SRCS ipu_compiler.cc DEPS popart graph ipu_utils graph_helper)
cc_library(ipu_backend SRCS ipu_backend.cc DEPS popart ipu_compiler graph framework_proto enforce ipu_utils ipu_strategy ipu_device ipu_executor graph_helper)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <popart/names.hpp>
namespace paddle {
namespace platform {
namespace ipu {
static constexpr const char *sIpuIndexAttr = "ipu_index";
static constexpr const char *sIpuStageAttr = "ipu_stage";
static constexpr const char *sOpIdentifyIdAttr = "op_identify_id";
static constexpr const char *sDebugInfoId = "__debug_info_id";
static constexpr const char *sBeta1 = "beta1";
static constexpr const char *sBeta2 = "beta2";
static constexpr const char *sBeta1Pow = "Beta1Pow";
static constexpr const char *sBeta2Pow = "Beta2Pow";
} // namespace ipu
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/ipu/device.h"
namespace paddle {
namespace platform {
namespace ipu {
Device::Device(const popart::DeviceInfo& device_info)
: id_(device_info.getId()), is_attached_(device_info.isAttached()) {
popart::DeviceType popart_device_type = device_info.getType();
switch (popart_device_type) {
case popart::DeviceType::IpuModel:
device_type_ = DeviceType::IpuModel;
break;
case popart::DeviceType::Ipu:
device_type_ = DeviceType::Ipu;
break;
default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "popart::DeviceType: unsupported type %d",
          static_cast<int>(popart_device_type)));
}
}
} // namespace ipu
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <popart/devicemanager.hpp>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace ipu {
enum class DeviceType { IpuModel = 0, Cpu, Ipu, OfflineIpu, Sim };
class Device {
public:
Device() {}
explicit Device(const popart::DeviceInfo& device_info);
int getId() const { return id_; }
bool isAttached() const { return is_attached_; }
DeviceType getType() const { return device_type_; }
private:
int id_;
bool is_attached_;
DeviceType device_type_;
  // TODO: add more fields in the future
};
} // namespace ipu
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/ipu/ipu_backend.h"
#include "paddle/fluid/platform/ipu/ipu_utils.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
namespace platform {
namespace ipu {
std::shared_ptr<IpuBackend> IpuBackend::instance_ = nullptr;
IpuBackend::IpuBackend() {
compiler_ = std::make_shared<Compiler>();
executor_ = std::make_unique<Executor>();
}
void IpuBackend::Clear() {
executor_.reset();
// detach device
if (device_ != nullptr && device_->isAttached()) {
device_->detach();
device_.reset();
device_ = nullptr;
}
}
IpuBackend::~IpuBackend() { Clear(); }
std::shared_ptr<IpuBackend> IpuBackend::GetInstance() {
if (!instance_) {
instance_.reset(new IpuBackend());
}
return instance_;
}
// This API should only be called from Python; it always returns a new object.
std::shared_ptr<IpuBackend> IpuBackend::GetNewInstance() {
instance_.reset(new IpuBackend());
return instance_;
}
void IpuBackend::Compile(framework::ir::Graph* graph,
const std::vector<std::string>& feed_list,
const std::vector<std::string>& fetch_list) {
VLOG(10) << "enter IpuBackend::Compile";
compiler_->InitInputs(graph, feed_list);
compiler_->LowerWeights(graph, scope_);
compiler_->LowerBody(graph);
compiler_->InitOutputs(fetch_list);
executor_->SetWeights(compiler_->GetWeights());
VLOG(10) << "leave IpuBackend::Compile";
}
void IpuBackend::Run(const std::vector<const framework::Tensor*>& inputs,
const std::vector<framework::Tensor*>& outputs,
const framework::ExecutionContext& ctx) {
Prepare();
auto inputs_id = compiler_->GetInputs();
auto outputs_id = compiler_->GetOutputs();
executor_->Run(inputs_id, inputs, outputs_id, outputs, ctx);
}
void IpuBackend::Prepare() {
if (is_prepared_) {
return;
} else {
is_prepared_ = true;
}
// convert Model to fp16
if (ipu_strategy_->enable_fp16) {
compiler_->ConvertProtoToFp16();
}
auto proto = compiler_->GetModelProto();
auto tensors = compiler_->GetTensors();
auto outputs = compiler_->GetOutputs();
executor_->Prepare(proto, tensors, outputs, device_);
}
void IpuBackend::SetScope(const framework::Scope& scope) {
scope_ = &scope;
executor_->SetScope(&scope);
}
void IpuBackend::SetIpuStrategy(const IpuStrategy& strategy) {
ipu_strategy_ = &strategy;
executor_->SetIpuStrategy(strategy);
compiler_->SetIpuStrategy(strategy);
}
size_t IpuBackend::GetNumDevices() {
// IpuModel
bool ipu_model = GetBoolEnv("POPLAR_IPUMODEL");
if (ipu_model) return 1;
// Real dev
size_t num_devices =
popart::DeviceManager::createDeviceManager().enumerateDevices().size();
PADDLE_ENFORCE_GT(
num_devices, 0,
platform::errors::Unavailable(
"Do not found any IPU devices, please make "
"sure Poplar sdk is enabled or enable ENV \"POPLAR_IPUMODEL=1\""));
return num_devices;
}
std::vector<int> IpuBackend::GetDeviceIds() {
bool ipu_model = GetBoolEnv("POPLAR_IPUMODEL");
if (ipu_model) {
return {0};
}
std::vector<int> device_ids;
auto devices =
popart::DeviceManager::createDeviceManager().enumerateDevices();
PADDLE_ENFORCE_GT(
devices.size(), 0,
platform::errors::Unavailable("Do not found any IPU devices, please make "
"sure Poplar sdk is enabled."));
for (auto device : devices) {
device_ids.push_back(device->getId());
}
return device_ids;
}
Device IpuBackend::GetDevice(int id) {
bool ipu_model = GetBoolEnv("POPLAR_IPUMODEL");
if (ipu_model) {
    std::map<std::string, std::string> deviceOpts{{"numIPUs", "1"}};
device_ = popart::DeviceManager::createDeviceManager().createIpuModelDevice(
deviceOpts);
Device device(*device_.get());
return device;
}
size_t num_devices = GetNumDevices();
  if (id < 0 || id >= static_cast<int>(num_devices)) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "device id %d is invalid, number of devices is %d", id, num_devices));
}
std::shared_ptr<popart::DeviceInfo> popart_device_info =
popart::DeviceManager::createDeviceManager().getDevice(
popart::SyncPattern::Full, id);
Device device(*popart_device_info.get());
return device;
}
void IpuBackend::AttachDevice(int id) {
  // Note: the compile-time ipu id is not necessarily the runtime ipu;
  // the device is acquired by count (UpperIpuNum) below.
  VLOG(10) << "compile ipu id = " << id;
bool ipu_model = GetBoolEnv("POPLAR_IPUMODEL");
if (ipu_model) {
return;
}
device_ = popart::DeviceManager::createDeviceManager().acquireAvailableDevice(
UpperIpuNum());
PADDLE_ENFORCE_NOT_NULL(
device_, platform::errors::Unavailable("Can't attach IPU, ipu_num = %d.",
UpperIpuNum()));
}
bool IpuBackend::DeviceIsAttached() { return device_ != nullptr; }
// The number of acquired IPUs must be a power of two; round num_ipus up.
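// e.g. num_ipus == 3 rounds up to 4; num_ipus == 4 stays 4 (2^2).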
int IpuBackend::UpperIpuNum() {
  PADDLE_ENFORCE_GT(ipu_strategy_->num_ipus, 0,
                    platform::errors::Unavailable(
                        "The number of IPUs is invalid, please make sure the "
                        "sharding or pipeline parameters are correct."));
int i = 0;
while (std::pow(2, i) < ipu_strategy_->num_ipus) {
i++;
}
return std::pow(2, i);
}
} // namespace ipu
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <popart/devicemanager.hpp>
#include <popart/names.hpp>
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/ipu/device.h"
#include "paddle/fluid/platform/ipu/ipu_compiler.h"
#include "paddle/fluid/platform/ipu/ipu_executor.h"
#include "paddle/fluid/platform/ipu/ipu_strategy.h"
namespace paddle {
namespace platform {
namespace ipu {
class IpuBackend {
  // IpuBackend is the center of paddle-ipu. Its functions include:
  // 1. Compile a paddle model into a popart model
  // 2. Run the popart model (inference or training)
  // 3. Request and release devices
  // 4. Other helper functions
public:
IpuBackend();
~IpuBackend();
void Clear();
  // Return the existing instance if it exists; otherwise create and return one.
static std::shared_ptr<IpuBackend> GetInstance();
  // Always returns a new instance_
static std::shared_ptr<IpuBackend> GetNewInstance();
  // What Compile does (via compiler_):
  // 1. map each paddle op to a popart op
  // 2. construct the popart onnx compute graph
void Compile(framework::ir::Graph *graph,
const std::vector<std::string> &feed_list,
const std::vector<std::string> &fetch_list);
  // What Run does:
  // 1. construct the forward onnx graph
  // 2. graph-level optimization
  // 3. autodiff
void Run(const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs,
const framework::ExecutionContext &ctx);
Executor &GetExecutor() { return *executor_; }
void SetScope(const framework::Scope &scope);
const framework::Scope *GetScope() { return scope_; }
void SetIpuStrategy(const IpuStrategy &strategy);
const IpuStrategy *GetIpuStrategy() { return ipu_strategy_; }
// Device
size_t GetNumDevices();
std::vector<int> GetDeviceIds();
Device GetDevice(int id);
void AttachDevice(int id);
bool DeviceIsAttached();
private:
int UpperIpuNum();
void Prepare();
private:
std::shared_ptr<Compiler> compiler_;
std::unique_ptr<Executor> executor_;
std::shared_ptr<popart::DeviceInfo> device_;
bool is_prepared_ = false;
// not own
const framework::Scope *scope_ = nullptr;
const IpuStrategy *ipu_strategy_ = nullptr;
private:
static std::shared_ptr<IpuBackend> instance_;
};
} // namespace ipu
} // namespace platform
} // namespace paddle
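As a rough usage sketch of the API above (the graph, scope, strategy, and tensors are assumed to be prepared by the caller; this driver function is hypothetical and not part of the commit):

#include <string>
#include <vector>
#include "paddle/fluid/platform/ipu/ipu_backend.h"
// Hypothetical driver: all arguments are assumed to be built elsewhere.
void RunOnIpu(paddle::framework::ir::Graph *graph,
              const paddle::framework::Scope &scope,
              const paddle::platform::ipu::IpuStrategy &strategy,
              const std::vector<std::string> &feed_list,
              const std::vector<std::string> &fetch_list,
              const std::vector<const paddle::framework::Tensor *> &inputs,
              const std::vector<paddle::framework::Tensor *> &outputs,
              const paddle::framework::ExecutionContext &ctx) {
  auto backend = paddle::platform::ipu::IpuBackend::GetInstance();
  backend->SetScope(scope);           // weights are looked up in this scope
  backend->SetIpuStrategy(strategy);  // num_ipus, is_training, fp16, ...
  backend->AttachDevice(0);           // no-op when POPLAR_IPUMODEL=1
  backend->Compile(graph, feed_list, fetch_list);
  backend->Run(inputs, outputs, ctx);  // Prepare() runs lazily on first call
}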
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/ipu/ipu_compiler.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/ipu/ipu_utils.h"
namespace paddle {
namespace platform {
namespace ipu {
template <typename T>
T GetAttrAllowNull(std::string attr, framework::OpDesc* op_desc) {
if (op_desc->HasAttr(attr)) {
return BOOST_GET_CONST(T, op_desc->GetAttr(attr));
} else {
return {};
}
}
template <typename T>
nonstd::optional<T> GetOptAttrAllowNull(std::string attr,
framework::OpDesc* op_desc) {
if (op_desc->HasAttr(attr)) {
return BOOST_GET_CONST(T, op_desc->GetAttr(attr));
} else {
return {};
}
}
Compiler::Compiler() {
builder_ = popart::Builder::create();
RegisterOpFunc();
}
Compiler::~Compiler() {}
void Compiler::RegisterOpFunc() {
VLOG(10) << "enter Compiler::RegisterOpFunc";
#define INT_VEC std::vector<std::int64_t>
#define FLOAT_VEC std::vector<float>
#define FLOAT float
#define INT std::int64_t
#define BOOL bool
#define STRING std::string
#define STRING_VEC std::vector<std::string>
#define NONE
#define ARG(Type, Name) , GetAttrAllowNull<Type>(#Name, op_desc)
#define OPT_ARG(Type, Name) , GetOptAttrAllowNull<Type>(#Name, op_desc)
#define POPART_CONST_ARG(Name) , const PopartConstant& Name
#define HOST_SIDE_CONST_ARG(Name) , const HostSideConstant& Name
#define POPART_ATTRIB_VEC_ARG(Name)
#define BODY_ARG(Name) NONE
name_function_ = {
#define OP_DECL(FuncName, OnnxImpl, Args) \
{#FuncName, [&](framework::OpDesc* op_desc) { \
auto op_type = op_desc->Type(); \
VLOG(10) << "build op:" << op_type << " args " << #Args; \
auto inputs = GetOpInputs(op_desc); \
auto output_names = GetOpOutputs(op_desc); \
auto debug_context = BuildDebugContext(op_desc); \
auto aiGraphcoreOpset = builder_->aiGraphcoreOpset1(); \
auto aiOnnxOpset = builder_->aiOnnxOpset11(); \
auto output_ids = OnnxImpl(inputs Args, debug_context); \
SetIpuIndexStage(output_ids, op_desc); \
InsertTensors(output_names, output_ids); \
}}, // NOLINT
#include "paddle/fluid/platform/ipu/supported_ops_autogen.h"
};
#undef OP_DECL
#undef BODY_ARG
#undef POPART_ATTRIB_VEC_ARG
#undef HOST_SIDE_CONST_ARG
#undef POPART_CONST_ARG
#undef OPT_ARG
#undef ARG
#undef NONE
#undef STRING_VEC
#undef STRING
#undef BOOL
#undef INT
#undef FLOAT
#undef FLOAT_VEC
#undef INT_VEC
}
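// For reference (expansion written out by hand, not generated code): an entry
// like OP_DECL(popart_gelu_v2, aiGraphcoreOpset.gelu, NONE) from
// supported_ops_autogen.h becomes roughly
//   {"popart_gelu_v2", [&](framework::OpDesc* op_desc) {
//      ...collect inputs / output_names / debug_context...
//      auto output_ids = aiGraphcoreOpset.gelu(inputs, debug_context);
//      SetIpuIndexStage(output_ids, op_desc);
//      InsertTensors(output_names, output_ids);
//   }},
// while ARG(FLOAT, scale) would splice in
//   , GetAttrAllowNull<float>("scale", op_desc)
// as an extra argument between `inputs` and `debug_context`.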
void Compiler::LowerBody(const framework::ir::Graph* graph) {
VLOG(10) << "enter Compiler::LowerBody";
auto nodes = framework::ir::TopologySortOperations(*graph);
for (auto* node : nodes) {
auto* op_desc = node->Op();
auto op_type = op_desc->Type();
VLOG(10) << "node->type: " << op_type;
if (op_type == "popart_constant") {
auto dims =
BOOST_GET_CONST(std::vector<int64_t>, op_desc->GetAttr("dims"));
auto dtype_ = BOOST_GET_CONST(int, op_desc->GetAttr("dtype"));
auto dtype = OnnxDtype2PopartType(dtype_);
popart::TensorInfo tensor_info{dtype, dims};
auto value_attr = op_desc->GetAttr("value");
auto const_data = std::unique_ptr<popart::ConstVoidData>{};
switch (dtype) {
case popart::DataType::FLOAT:
const_data.reset(new popart::ConstVoidData(
BOOST_GET_CONST(std::vector<float>, value_attr).data(),
tensor_info));
break;
case popart::DataType::INT32:
const_data.reset(new popart::ConstVoidData(
BOOST_GET_CONST(std::vector<int>, value_attr).data(),
tensor_info));
break;
case popart::DataType::DOUBLE:
const_data.reset(new popart::ConstVoidData(
BOOST_GET_CONST(std::vector<double>, value_attr).data(),
tensor_info));
break;
case popart::DataType::INT64:
const_data.reset(new popart::ConstVoidData(
BOOST_GET_CONST(std::vector<int64_t>, value_attr).data(),
tensor_info));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"The popart datatype is not supported, popart::DataType is %d",
dtype));
}
popart::TensorId result = builder_->aiOnnxOpset11().constant(*const_data);
SetIpuIndexStage(result, op_desc);
InsertTensors(GetOpOutputs(op_desc), result);
} else if (op_type == "popart_batchnormalization") {
auto inputs = GetOpInputs(op_desc);
auto outputs = GetOpOutputs(op_desc);
auto num_outputs = outputs.size();
auto epsilon = BOOST_GET_CONST(float, op_desc->GetAttr("epsilon"));
auto momentum = BOOST_GET_CONST(float, op_desc->GetAttr("momentum"));
auto result = builder_->aiOnnxOpset11().batchnormalization(
inputs, num_outputs, epsilon, momentum);
SetIpuIndexStage(result, op_desc);
InsertTensors(GetOpOutputs(op_desc), result);
} else if (op_type == "popart_nllloss") {
auto inputs = GetOpInputs(op_desc);
auto ignoreIndex = BOOST_GET_CONST(int, op_desc->GetAttr("ignoreIndex"));
auto result = builder_->aiGraphcoreOpset1().nllloss(
inputs, popart::ReductionType::NoReduction, ignoreIndex);
SetIpuIndexStage(result, op_desc);
InsertTensors(GetOpOutputs(op_desc), result);
} else if (op_type == "popart_topk") {
auto inputs = GetOpInputs(op_desc);
auto outputs = GetOpOutputs(op_desc);
int64_t axis = BOOST_GET_CONST(int64_t, op_desc->GetAttr("axis"));
int sorted_INT32 = BOOST_GET_CONST(int, op_desc->GetAttr("sorted"));
int64_t sorted = int64_t{sorted_INT32};
auto aiOnnxOpset = builder_->aiOnnxOpset11();
popart::ConvInputs result;
if (inputs.size() == 2) {
VLOG(10)
<< "[Compiler::LowerBody] size of inputs for <popart_topk> is 2";
result = aiOnnxOpset.topk(inputs, axis, sorted);
} else if (inputs.size() == 1) {
VLOG(10)
<< "[Compiler::LowerBody] size of inputs for <popart_topk> is 1";
int64_t k = BOOST_GET_CONST(int64_t, op_desc->GetAttr("k"));
popart::TensorInfo kShape{"INT64", std::vector<int64_t>{1}};
popart::ConstVoidData kData = {&k, kShape};
auto K_t = aiOnnxOpset.constant(kData);
result = aiOnnxOpset.topk({inputs[0], K_t}, axis, sorted);
}
result[1] = aiOnnxOpset.cast({result[1]}, "INT32");
SetIpuIndexStage(result, op_desc);
VLOG(10) << "[Compiler::LowerBody] output[1]: " << outputs[1];
VLOG(10) << "[Compiler::LowerBody] output[1]: "
<< GetOpOutputs(op_desc)[1] << " -> " << result[1];
tensors_.emplace(GetOpOutputs(op_desc)[1], result[1]); // topk indices
VLOG(10) << "[Compiler::LowerBody] output[0]: " << outputs[0];
VLOG(10) << "[Compiler::LowerBody] output[0]: "
<< GetOpOutputs(op_desc)[0] << " -> " << result[0];
tensors_.emplace(GetOpOutputs(op_desc)[0], result[0]); // topk values
} else {
auto itr = name_function_.find(op_type);
if (itr != name_function_.end()) {
itr->second(node->Op());
} else {
PADDLE_THROW(platform::errors::NotFound(
"Op %s is not registered in popart canonicalization", op_type));
}
}
}
VLOG(10) << "leave Compiler::LowerBody";
}
void Compiler::InitInputs(framework::ir::Graph* graph,
const std::vector<std::string>& feed_list) {
for (const auto& feed_name : feed_list) {
feed_list_.push_back(feed_name);
for (const framework::ir::Node* n : graph->Nodes()) {
if (n->IsVar()) {
auto* var_desc = n->Var();
if (feed_name == var_desc->Name()) {
VLOG(10) << "feed_name= " << var_desc->Name();
auto data_type = VarType2PopartType(var_desc->GetDataType());
if (ipu_strategy_->enable_fp16) {
data_type = popart::DataType::FLOAT16;
}
popart::TensorInfo input_info{data_type, var_desc->GetShape()};
VLOG(10) << "popart input_info = " << input_info;
popart::TensorId tensor_id =
builder_->addInputTensor(input_info, feed_name);
VLOG(10) << "popart input tensor id = " << tensor_id;
inputs_.push_back(tensor_id);
tensors_.emplace(var_desc->Name(), tensor_id);
}
}
}
}
}
void Compiler::InitOutputs(const std::vector<std::string>& fetch_list) {
for (const auto& fetch_name : fetch_list) {
fetch_list_.push_back(fetch_name);
auto tensor = tensors_.find(fetch_name);
PADDLE_ENFORCE_NE(tensor, tensors_.end(),
platform::errors::NotFound(
"output tensor %s does not exist.", fetch_name));
VLOG(10) << "fetch_name= " << fetch_name;
VLOG(10) << "popart output tensor id = " << tensor->second;
builder_->addOutputTensor(tensor->second);
outputs_.push_back(tensor->second);
}
}
void Compiler::LowerWeights(const framework::ir::Graph* graph,
const framework::Scope* scope_) {
PADDLE_ENFORCE_NOT_NULL(scope_,
platform::errors::PreconditionNotMet(
"You should call set_scope before LowerWeights"));
  // At this step the graph doesn't contain optimizer-related state.
for (const auto* node : graph->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
if (node->Var()->Persistable() && node->inputs.empty()) {
auto var_name = node->Var()->Name();
        // workaround: https://github.com/graphcore/Paddle/issues/151
if (tensors_.count(var_name) != 0) {
continue;
}
auto var = scope_->FindVar(var_name);
if (var) {
auto tensor = var->Get<framework::LoDTensor>();
auto dtype = VarType2PopartType(tensor.type());
auto shape = std::vector<int64_t>();
for (size_t i = 0; i < tensor.dims().size(); ++i) {
shape.push_back(tensor.dims().at(i));
}
popart::TensorInfo tensor_info(dtype, shape);
popart::ConstVoidData const_data{tensor.data<void>(), tensor_info};
popart::TensorId result =
builder_->addInitializedInputTensor(const_data, var_name);
tensors_.emplace(var_name, result);
weights_.push_back(result);
}
}
}
}
}
void Compiler::InsertTensors(const std::vector<std::string>& output_names,
const std::vector<std::string>& tensor_ids) {
PADDLE_ENFORCE_EQ(output_names.size(), tensor_ids.size(),
platform::errors::Fatal("InsertTensors size mismatch"));
  for (size_t i = 0; i < tensor_ids.size(); i++) {
    tensors_.emplace(output_names[i], tensor_ids[i]);
  }
}
void Compiler::InsertTensors(const std::vector<std::string>& output_names,
const std::string& tensor_id) {
PADDLE_ENFORCE_EQ(output_names.size(), 1,
platform::errors::Fatal("InsertTensors size mismatch"));
tensors_.emplace(output_names[0], tensor_id);
}
void Compiler::SetIpuIndexStage(const std::vector<std::string>& tensor_ids,
const framework::OpDesc* op_desc) {
VLOG(10) << "enter Compiler::SetIpuIndexStage";
auto tensor_ids_set =
std::set<std::string>(tensor_ids.begin(), tensor_ids.end());
if (op_desc->HasAttr(sIpuIndexAttr)) {
auto ipu_index = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuIndexAttr));
builder_->virtualGraph(tensor_ids_set, ipu_index);
VLOG(10) << "set " << sIpuIndexAttr << " = " << ipu_index
<< " for op: " << op_desc->Type();
if (op_desc->HasAttr(sIpuStageAttr)) {
auto ipu_stage = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuStageAttr));
builder_->pipelineStage(tensor_ids_set, ipu_stage);
VLOG(10) << "set " << sIpuStageAttr << "= " << ipu_stage
<< " for op: " << op_desc->Type();
}
}
VLOG(10) << "leave Compiler::SetIpuIndexStage";
}
void Compiler::SetIpuIndexStage(const std::string& tensor_id,
const framework::OpDesc* op_desc) {
VLOG(10) << "enter Compiler::SetIpuIndexStage";
if (op_desc->HasAttr(sIpuIndexAttr)) {
auto ipu_index = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuIndexAttr));
builder_->virtualGraph(tensor_id, ipu_index);
VLOG(10) << "set " << sIpuIndexAttr << " = " << ipu_index
<< " for op: " << op_desc->Type();
if (op_desc->HasAttr(sIpuStageAttr)) {
auto ipu_stage = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuStageAttr));
builder_->pipelineStage(tensor_id, ipu_stage);
VLOG(10) << "set " << sIpuStageAttr << "= " << ipu_stage
<< " for op: " << op_desc->Type();
}
}
VLOG(10) << "leave Compiler::SetIpuIndexStage";
}
std::vector<popart::TensorId>& Compiler::GetWeights() { return weights_; }
// convertFloatsToHalfs
void Compiler::ConvertProtoToFp16() {
popart::GraphTransformer graph_transformer(builder_->getModelProto());
graph_transformer.convertFloatsToHalfs();
converted_proto_ = graph_transformer.getModelProto();
}
std::string Compiler::GetModelProto() {
if (converted_proto_.length()) {
return converted_proto_;
}
return builder_->getModelProto();
}
void Compiler::SaveModelProto(const std::string& path) {
builder_->saveModelProto(path);
}
void Compiler::SaveModelProtoNoCheck(const std::string& path) {
auto proto = GetModelProto();
std::ofstream onnxfile(path, std::ios_base::binary);
onnxfile.write(proto.data(), proto.size());
onnxfile.close();
}
std::vector<std::string> Compiler::GetOpInputs(const framework::OpDesc* op) {
auto ins = op->Input("__inputs__");
std::vector<std::string> inputs;
for (const auto& in : ins) {
if (tensors_.find(in) != tensors_.end()) {
inputs.push_back(tensors_[in]);
} else {
inputs.push_back(in);
}
}
return inputs;
}
std::vector<std::string> Compiler::GetOpOutputs(const framework::OpDesc* op) {
  // OpDesc::Output returns by value, so return a copy rather than a
  // dangling reference to a temporary.
  return op->Output("__outputs__");
}
popart::DebugContext Compiler::BuildDebugContext(const framework::OpDesc* op) {
auto op_identify_id =
BOOST_GET_CONST(std::string, op->GetAttr(sOpIdentifyIdAttr));
VLOG(10) << "op_identify_id of op: " << op->Type() << " is "
<< op_identify_id;
return popart::DebugContext(op_identify_id);
}
} // namespace ipu
} // namespace platform
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <popart/builder.hpp>
#include <popart/graphtransformer.hpp>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/ipu/common.h"
#include "paddle/fluid/platform/ipu/ipu_strategy.h"
namespace paddle {
namespace platform {
namespace ipu {
class Compiler {
public:
Compiler();
~Compiler();
void RegisterOpFunc();
void LowerBody(const framework::ir::Graph *graph);
void InitInputs(framework::ir::Graph *graph,
const std::vector<std::string> &feed_list);
void InitOutputs(const std::vector<std::string> &fetch_list);
void LowerWeights(const framework::ir::Graph *graph,
const framework::Scope *scope_);
void InsertTensors(const std::vector<std::string> &output_names,
const std::vector<std::string> &tensor_ids);
void InsertTensors(const std::vector<std::string> &output_names,
const std::string &tensor_id);
void SetIpuIndexStage(const std::vector<std::string> &tensor_ids,
const framework::OpDesc *op_desc);
void SetIpuIndexStage(const std::string &tensor_id,
const framework::OpDesc *op_desc);
std::vector<popart::TensorId> GetInputs() { return inputs_; }
std::vector<popart::TensorId> GetOutputs() { return outputs_; }
std::map<std::string, popart::TensorId> GetTensors() { return tensors_; }
std::vector<popart::TensorId> &GetWeights();
std::string GetModelProto();
  void SetIpuStrategy(const IpuStrategy &strategy) {
    ipu_strategy_ = &strategy;
  }
void SaveModelProto(const std::string &path);
void SaveModelProtoNoCheck(const std::string &path);
void ConvertProtoToFp16();
private:
std::vector<std::string> GetOpInputs(const framework::OpDesc *op);
  std::vector<std::string> GetOpOutputs(const framework::OpDesc *op);
popart::DebugContext BuildDebugContext(const framework::OpDesc *op);
private:
std::unique_ptr<popart::Builder> builder_;
using OpFunc = std::function<void(framework::OpDesc *op_desc)>;
std::unordered_map<std::string, OpFunc> name_function_;
  // stateful: maps paddle var names to popart tensor ids
std::map<std::string, popart::TensorId> tensors_;
  // feed_list_ & fetch_list_ store paddle variable names
std::vector<std::string> feed_list_;
std::vector<std::string> fetch_list_;
  // inputs_ & outputs_ store popart tensor ids
std::vector<popart::TensorId> inputs_;
std::vector<popart::TensorId> outputs_;
  // popart tensor ids of the weights
std::vector<popart::TensorId> weights_;
std::string converted_proto_ = "";
const IpuStrategy *ipu_strategy_ = nullptr;
};
} // namespace ipu
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/ipu/ipu_executor.h"
namespace paddle {
namespace platform {
namespace ipu {
Executor::Executor() {}
Executor::~Executor() {}
void Executor::Prepare(const std::string &proto,
const std::map<std::string, popart::TensorId> &tensors,
const std::vector<popart::TensorId> &outputs,
std::shared_ptr<popart::DeviceInfo> device) {
auto art = popart::AnchorReturnType("All");
std::map<popart::TensorId, popart::AnchorReturnType> anchor_ids;
for (const auto &id : outputs) {
anchor_ids.emplace(id, art);
}
auto dataFlow = popart::DataFlow(ipu_strategy_->batches_per_step, anchor_ids);
PADDLE_ENFORCE_NOT_NULL(device, platform::errors::Unavailable(
"IPU device isn't attached, please call "
"IpuBackend::AttachDevice(id) first."));
if (ipu_strategy_ != nullptr && ipu_strategy_->is_training) {
VLOG(10) << "Creating TrainingSession from Onnx Model...";
auto popart_optimizer = GetPopartOptimizer(opt_info);
auto it = tensors.find(opt_info.GetLoss());
PADDLE_ENFORCE_NE(
it, tensors.end(),
paddle::platform::errors::InvalidArgument(
"loss_id = %s doesn't exist in popart graph.", opt_info.GetLoss()));
session_ = popart::TrainingSession::createFromOnnxModel(
proto, dataFlow, it->second, *popart_optimizer, device,
popart::InputShapeInfo(), ipu_strategy_->popart_options_,
popart::Patterns(popart::PatternsLevel::Default));
} else {
VLOG(10) << "Creating InferenceSession from Onnx Model...";
session_ = popart::InferenceSession::createFromOnnxModel(
proto, dataFlow, device, popart::InputShapeInfo(),
ipu_strategy_->popart_options_,
popart::Patterns(popart::PatternsLevel::Default));
}
VLOG(10) << "Creating session from Onnx Model...done";
VLOG(10) << "Preparing session device...";
session_->prepareDevice();
VLOG(10) << "Preparing session device...done";
SetWeightsIO();
VLOG(10) << "Copy weights from paddle to popart...";
WeightsFromPaddle();
VLOG(10) << "Copy weights from paddle to popart...done";
VLOG(10) << "Copy weights from host to device...";
session_->weightsFromHost();
VLOG(10) << "Copy weights from host to device...done";
if (ipu_strategy_->save_init_onnx) {
session_->modelToHost("test_init.onnx");
}
}
void Executor::Run(const std::vector<popart::TensorId> &inputs_id,
const std::vector<const framework::Tensor *> &inputs,
const std::vector<popart::TensorId> &outputs_id,
const std::vector<framework::Tensor *> &outputs,
const framework::ExecutionContext &ctx) {
// inputs
std::map<popart::TensorId, popart::IArray &> popart_inputs;
std::map<popart::TensorId, PaddleIArray> input_wrappers;
  for (size_t i = 0; i < inputs.size(); i++) {
    auto tensor_id = inputs_id[i];
    // PaddleIArray only reads the tensor, so casting away constness is safe.
    auto *tensor = const_cast<framework::Tensor *>(inputs[i]);
    input_wrappers.emplace(tensor_id, PaddleIArray(tensor));
    popart_inputs.emplace(tensor_id, input_wrappers.at(tensor_id));
  }
// anchors
std::map<popart::TensorId, popart::IArray &> popart_anchors;
std::map<popart::TensorId, PaddleIArray> anchor_wrappers;
  for (size_t i = 0; i < outputs.size(); i++) {
    auto tensor_id = outputs_id[i];
    auto *tensor = outputs[i];
// get dims & dtype from session
auto fetch_info = session_->getInfo(tensor_id);
auto output_shape = fetch_info.shape();
if (ipu_strategy_->batches_per_step > 1) {
output_shape.insert(output_shape.begin(),
ipu_strategy_->batches_per_step);
}
tensor->Resize(framework::make_ddim(output_shape));
auto fetch_dtype = fetch_info.dataType();
auto paddle_type = PopartType2VarType(fetch_dtype);
tensor->mutable_data(ctx.GetPlace(), paddle_type);
anchor_wrappers.emplace(tensor_id, PaddleIArray(tensor));
popart_anchors.emplace(tensor_id, anchor_wrappers.at(tensor_id));
}
if (ipu_strategy_ != nullptr && ipu_strategy_->is_training) {
VLOG(10) << "Update optimizer learning rate...";
SetLR(GetLRFromScope());
auto popart_optimizer = GetPopartOptimizer(opt_info);
auto &session = dynamic_cast<popart::TrainingSession &>(*session_);
session.updateOptimizerFromHost(popart_optimizer.get());
}
popart::StepIO stepio(popart_inputs, popart_anchors);
VLOG(10) << "Running...";
session_->run(stepio);
VLOG(10) << "Running...done";
if (ipu_strategy_ != nullptr && ipu_strategy_->is_training) {
session_->weightsToHost();
WeightsToPaddle();
if (ipu_strategy_->save_last_onnx) {
session_->modelToHost("test_last.onnx");
}
}
}
void Executor::SetOptimizerType(const std::string &type) {
opt_info.SetType(type);
}
void Executor::SetLR(float lr_rate) { opt_info.SetLR(lr_rate); }
void Executor::SetOptimizerAttr(const std::string &attr, float value) {
opt_info.SetAttr(attr, value);
}
void Executor::SetLoss(const std::string &loss) { opt_info.SetLoss(loss); }
void Executor::SetLRVarName(const std::string &name) {
opt_info.SetLRVarName(name);
}
void Executor::SetWeights(const std::vector<popart::TensorId> &weights) {
weights_ = weights;
}
void Executor::SetWeightsIO() {
  auto opt_type = opt_info.GetType();
  // Hoisted out of the loops: an unsupported optimizer skips everything.
  if (!IsOptimizerSupported(opt_type)) {
    return;
  }
  auto pre_post_fix = GetOptPrePostfix(opt_type);
  for (const auto &weight_id : weights_) {
    for (const auto &pair : pre_post_fix) {
// pair.first : popart prefix, pair.second : paddle postfix
auto popart_var_name = pair.first + weight_id;
auto paddle_var_name = weight_id + pair.second;
if (scope_->FindVar(paddle_var_name) == nullptr) {
continue;
}
auto var = scope_->GetVar(paddle_var_name);
auto data_ptr = var->GetMutable<framework::LoDTensor>()->data<float>();
auto tensor_info = session_->getInfo(popart_var_name);
weights_io_.insert(popart_var_name, {data_ptr, tensor_info});
}
}
}
void Executor::WeightsFromPaddle() { session_->writeWeights(weights_io_); }
void Executor::WeightsToPaddle() { session_->readWeights(weights_io_); }
void Executor::SetIpuStrategy(const IpuStrategy &strategy) {
ipu_strategy_ = &strategy;
}
float Executor::GetLRFromScope() {
auto lr_var = scope_->GetVar(opt_info.GetLRVarName());
auto tensor = lr_var->Get<framework::LoDTensor>();
  PADDLE_ENFORCE_EQ(tensor.type(), framework::proto::VarType::FP32,
                    platform::errors::InvalidArgument(
                        "LR requires float, but got (%s).", tensor.type()));
return tensor.data<float>()[0];
}
} // namespace ipu
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <popart/dataflow.hpp>
#include <popart/names.hpp>
#include <popart/session.hpp>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/ipu/common.h"
#include "paddle/fluid/platform/ipu/ipu_optimizer.h"
#include "paddle/fluid/platform/ipu/ipu_strategy.h"
#include "paddle/fluid/platform/ipu/ipu_utils.h"
namespace paddle {
namespace platform {
namespace ipu {
class Executor {
public:
Executor();
~Executor();
void Prepare(const std::string &proto,
const std::map<std::string, popart::TensorId> &tensors,
const std::vector<popart::TensorId> &outputs,
std::shared_ptr<popart::DeviceInfo> device);
void Run(const std::vector<popart::TensorId> &inputs_id,
const std::vector<const framework::Tensor *> &inputs,
const std::vector<popart::TensorId> &outputs_id,
const std::vector<framework::Tensor *> &outputs,
const framework::ExecutionContext &ctx);
// Optimizer
void SetOptimizerType(const std::string &type);
void SetOptimizerAttr(const std::string &attr, float value);
void SetLoss(const std::string &loss);
void SetLR(float lr_rate);
void SetLRVarName(const std::string &name);
void SetWeights(const std::vector<popart::TensorId> &info);
void SetWeightsIO();
void WeightsFromPaddle();
void WeightsToPaddle();
// Scope
void SetScope(const framework::Scope *scope) { scope_ = scope; }
// Strategy
void SetIpuStrategy(const IpuStrategy &strategy);
private:
float GetLRFromScope();
public:
OptmizerMetaInfo opt_info;
std::unique_ptr<popart::Session> session_;
private:
const framework::Scope *scope_ = nullptr;
const IpuStrategy *ipu_strategy_ = nullptr;
popart::WeightsIO weights_io_;
std::vector<popart::TensorId> weights_;
};
} // namespace ipu
} // namespace platform
} // namespace paddle
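As a hedged sketch of how the optimizer state above is meant to be populated (attribute names follow GetPopartOptimizer in ipu_optimizer.cc; the loss tensor id and LR variable name are hypothetical):

#include "paddle/fluid/platform/ipu/ipu_executor.h"
// Hypothetical setup for Adam training; values are illustrative only.
void ConfigureAdam(paddle::platform::ipu::Executor *executor) {
  executor->SetOptimizerType("adam");  // see OptTypeStr2Enum
  executor->SetOptimizerAttr("beta1", 0.9f);
  executor->SetOptimizerAttr("beta2", 0.999f);
  executor->SetOptimizerAttr("epsilon", 1e-8f);
  executor->SetLoss("mean_0.tmp_0");          // hypothetical loss tensor id
  executor->SetLRVarName("learning_rate_0");  // re-read each step (GetLRFromScope)
  executor->SetLR(0.001f);
}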
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/ipu/ipu_optimizer.h"
namespace paddle {
namespace platform {
namespace ipu {
OptmizerMetaInfo::OptmizerMetaInfo() {}
OptmizerMetaInfo::~OptmizerMetaInfo() {}
void OptmizerMetaInfo::SetType(const std::string &type) {
type_ = OptTypeStr2Enum(type);
}
float OptmizerMetaInfo::GetAttr(const std::string &attr,
float default_value) const {
if (attrs_.count(attr) == 0) {
return default_value;
}
return attrs_.at(attr);
}
void OptmizerMetaInfo::SetAttr(const std::string &attr, float value) {
attrs_[attr] = value;
}
OptimizerType OptTypeStr2Enum(const std::string type) {
if (type == "sgd") {
return OptimizerType::SGD;
} else if (type == "adam") {
return OptimizerType::Adam;
} else if (type == "lamb") {
return OptimizerType::Lamb;
} else {
return OptimizerType::Undefined;
}
}
std::unique_ptr<popart::Optimizer> GetPopartOptimizer(
const OptmizerMetaInfo &opt_meta_info) {
auto opt_type = opt_meta_info.GetType();
PADDLE_ENFORCE_NE(
opt_type, OptimizerType::Undefined,
platform::errors::InvalidArgument("Optimizer type have not been set."));
if (opt_type == OptimizerType::SGD) {
auto optimizer = std::make_unique<popart::SGD>(
popart::OptimizerValue(opt_meta_info.GetLR(), false),
popart::OptimizerValue(popart::SGD::getUnsetWeightDecay()),
popart::OptimizerValue(popart::SGD::getUnsetMomentum()),
popart::OptimizerValue(popart::SGD::getUnsetDampening()),
popart::OptimizerValue(popart::SGD::getUnsetVelocityScaling()),
popart::OptimizerValue(popart::SGD::getUnsetLossScaling()));
return optimizer;
} else if (opt_type == OptimizerType::Adam) {
auto optimizer = std::make_unique<popart::Adam>(
popart::OptimizerValue(opt_meta_info.GetLR(), false),
popart::OptimizerValue(popart::Adam::getUnsetWeightDecay()),
popart::OptimizerValue(opt_meta_info.GetAttr("beta1"), false),
popart::OptimizerValue(opt_meta_info.GetAttr("beta2"), false),
popart::OptimizerValue(opt_meta_info.GetAttr("epsilon"), false),
popart::OptimizerValue(popart::Adam::getUnsetLossScaling()),
popart::AdamMode::Adam, popart::WeightDecayMode::Decay,
popart::DataType::FLOAT, popart::DataType::FLOAT,
popart::DataType::FLOAT);
return optimizer;
} else if (opt_type == OptimizerType::Lamb) {
auto optimizer = std::make_unique<popart::Adam>(
popart::OptimizerValue(opt_meta_info.GetLR(), false),
popart::OptimizerValue(opt_meta_info.GetAttr("weight_decay"), false),
popart::OptimizerValue(opt_meta_info.GetAttr("beta1"), false),
popart::OptimizerValue(opt_meta_info.GetAttr("beta2"), false),
popart::OptimizerValue(opt_meta_info.GetAttr("epsilon"), false),
popart::OptimizerValue(popart::Adam::getUnsetLossScaling()),
popart::AdamMode::Lamb, popart::WeightDecayMode::Decay,
popart::DataType::FLOAT, popart::DataType::FLOAT,
popart::DataType::FLOAT);
return optimizer;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Optimizer %d is not implemented now.", static_cast<int>(opt_type)));
}
}
bool IsOptimizerSupported(OptimizerType type) {
switch (type) {
case OptimizerType::SGD:
case OptimizerType::Adam:
case OptimizerType::Lamb:
return true;
default:
return false;
}
}
std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
OptimizerType opt_type) {
  // format: {popart_prefix, paddle_postfix}, ...
std::vector<std::pair<std::string, std::string>> pre_post_fix;
switch (opt_type) {
case OptimizerType::SGD:
pre_post_fix.push_back(std::make_pair("", ""));
break;
case OptimizerType::Adam:
case OptimizerType::Lamb:
pre_post_fix.push_back(std::make_pair("", ""));
pre_post_fix.push_back(std::make_pair("Accl1___", "_moment1_0"));
pre_post_fix.push_back(std::make_pair("Accl2___", "_moment2_0"));
pre_post_fix.push_back(std::make_pair("Step___", "_beta1_pow_acc_0"));
break;
default:
pre_post_fix.push_back(std::make_pair("", ""));
break;
}
return pre_post_fix;
}
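// Illustration with a hypothetical weight id "fc_0.w_0" under Adam/Lamb
// (cf. Executor::SetWeightsIO: popart_var = prefix + weight_id,
//  paddle_var = weight_id + postfix):
//   "fc_0.w_0"            <->  "fc_0.w_0"
//   "Accl1___fc_0.w_0"    <->  "fc_0.w_0_moment1_0"
//   "Accl2___fc_0.w_0"    <->  "fc_0.w_0_moment2_0"
//   "Step___fc_0.w_0"     <->  "fc_0.w_0_beta1_pow_acc_0"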
} // namespace ipu
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <popart/adam.hpp>
#include <popart/names.hpp>
#include <popart/optimizer.hpp>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace ipu {
enum class OptimizerType { SGD = 0, Adam, Lamb, Undefined };
class OptmizerMetaInfo {
public:
OptmizerMetaInfo();
~OptmizerMetaInfo();
void SetType(const std::string &type);
OptimizerType GetType() const { return type_; }
void SetAttr(const std::string &attr, float value);
float GetAttr(const std::string &attr, float default_value = 0.0f) const;
void SetLoss(const std::string &loss) { loss_ = loss; }
std::string GetLoss() const { return loss_; }
void SetLR(float lr_rate) { lr_rate_ = lr_rate; }
float GetLR() const { return lr_rate_; }
void SetLRVarName(const std::string &name) { lr_var_name_ = name; }
std::string GetLRVarName() const { return lr_var_name_; }
private:
// type: adam, sgd, ...
OptimizerType type_ = OptimizerType::Undefined;
// loss: loss TensorId
std::string loss_;
// attrs: beta1, beta2, ...
std::map<std::string, float> attrs_;
// learning rate
float lr_rate_ = 1.0;
std::string lr_var_name_;
};
OptimizerType OptTypeStr2Enum(const std::string type);
std::unique_ptr<popart::Optimizer> GetPopartOptimizer(
const OptmizerMetaInfo &info);
bool IsOptimizerSupported(OptimizerType type);
std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
OptimizerType type);
} // namespace ipu
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/ipu/ipu_strategy.h"
namespace paddle {
namespace platform {
namespace ipu {} // namespace ipu
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <popart/sessionoptions.hpp>
namespace paddle {
namespace platform {
namespace ipu {
using VirtualGraphMode = popart::VirtualGraphMode;
struct IpuStrategy {
int num_ipus = 1;
int batches_per_step = 1;
int batch_size = 1;
bool is_training = true;
bool save_init_onnx = false;
bool save_last_onnx = true;
popart::SessionOptions popart_options_;
bool need_avg_shard = false;
bool enable_fp16 = false;
};
} // namespace ipu
} // namespace platform
} // namespace paddle
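As a hedged configuration sketch for the struct above (field values are hypothetical; popart_options_ exposes popart::SessionOptions directly):

#include "paddle/fluid/platform/ipu/ipu_strategy.h"
// Hypothetical training configuration; tune the values to your model.
paddle::platform::ipu::IpuStrategy MakeTrainingStrategy() {
  paddle::platform::ipu::IpuStrategy strategy;
  strategy.num_ipus = 4;           // rounded up to a power of two by UpperIpuNum()
  strategy.batches_per_step = 16;  // device-side loop count for popart::DataFlow
  strategy.is_training = true;     // TrainingSession instead of InferenceSession
  strategy.enable_fp16 = false;    // true triggers Compiler::ConvertProtoToFp16()
  strategy.popart_options_.virtualGraphMode =
      paddle::platform::ipu::VirtualGraphMode::Manual;  // honor sIpuIndexAttr
  return strategy;
}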
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/ipu/ipu_utils.h"
namespace paddle {
namespace platform {
namespace ipu {
void* PaddleIArray::data() { return tensor_->data<void>(); }
popart::DataType PaddleIArray::dataType() const {
return VarType2PopartType(tensor_->type());
}
std::size_t PaddleIArray::rank() const { return tensor_->dims().size(); }
int64_t PaddleIArray::dim(size_t index) const {
return tensor_->dims().at(index);
}
std::size_t PaddleIArray::nelms() const {
return std::accumulate(shape_.begin(), shape_.end(), static_cast<int64_t>(1),
std::multiplies<int64_t>());
}
const popart::Shape PaddleIArray::shape() const { return shape_; }
popart::DataType VarType2PopartType(
const framework::proto::VarType::Type type) {
switch (type) {
case framework::proto::VarType::UINT8:
return popart::DataType::UINT8;
case framework::proto::VarType::INT8:
return popart::DataType::INT8;
case framework::proto::VarType::INT16:
return popart::DataType::INT16;
case framework::proto::VarType::INT32:
return popart::DataType::INT32;
case framework::proto::VarType::INT64:
return popart::DataType::INT64;
case framework::proto::VarType::BOOL:
return popart::DataType::BOOL;
case framework::proto::VarType::FP64:
return popart::DataType::DOUBLE;
case framework::proto::VarType::FP32:
return popart::DataType::FLOAT;
case framework::proto::VarType::FP16:
return popart::DataType::FLOAT16;
case framework::proto::VarType::BF16:
return popart::DataType::BFLOAT16;
case framework::proto::VarType::COMPLEX64:
return popart::DataType::COMPLEX64;
case framework::proto::VarType::COMPLEX128:
return popart::DataType::COMPLEX128;
default:
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Unsupported Paddle var type."));
}
}
framework::proto::VarType::Type PopartType2VarType(
const popart::DataType type) {
switch (type) {
case popart::DataType::UINT8:
return framework::proto::VarType::UINT8;
case popart::DataType::INT8:
return framework::proto::VarType::INT8;
case popart::DataType::INT16:
return framework::proto::VarType::INT16;
case popart::DataType::INT32:
return framework::proto::VarType::INT32;
case popart::DataType::INT64:
return framework::proto::VarType::INT64;
case popart::DataType::BOOL:
return framework::proto::VarType::BOOL;
case popart::DataType::DOUBLE:
return framework::proto::VarType::FP64;
case popart::DataType::FLOAT:
return framework::proto::VarType::FP32;
case popart::DataType::FLOAT16:
return framework::proto::VarType::FP16;
case popart::DataType::BFLOAT16:
return framework::proto::VarType::BF16;
case popart::DataType::COMPLEX64:
return framework::proto::VarType::COMPLEX64;
case popart::DataType::COMPLEX128:
return framework::proto::VarType::COMPLEX128;
default:
      PADDLE_THROW(paddle::platform::errors::Unavailable(
          "Unsupported popart data type."));
}
}
popart::DataType OnnxDtype2PopartType(const int type) {
auto dtype = static_cast<ONNXDataType>(type);
switch (dtype) {
case ONNXDataType::BOOL:
return popart::DataType::BOOL;
case ONNXDataType::INT16:
return popart::DataType::INT16;
case ONNXDataType::INT32:
return popart::DataType::INT32;
case ONNXDataType::INT64:
return popart::DataType::INT64;
case ONNXDataType::FLOAT16:
return popart::DataType::FLOAT16;
case ONNXDataType::FLOAT:
return popart::DataType::FLOAT;
case ONNXDataType::DOUBLE:
return popart::DataType::DOUBLE;
case ONNXDataType::UINT8:
return popart::DataType::UINT8;
case ONNXDataType::INT8:
return popart::DataType::INT8;
case ONNXDataType::BFLOAT16:
return popart::DataType::BFLOAT16;
case ONNXDataType::COMPLEX64:
return popart::DataType::COMPLEX64;
case ONNXDataType::COMPLEX128:
return popart::DataType::COMPLEX128;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported ONNX data type: %d.", dtype));
}
}
// Returns true if the env var `str` is set to 1/true/True/TRUE.
bool GetBoolEnv(std::string str) {
char* str_val = getenv(str.c_str());
if (str_val == NULL) {
return false;
} else {
bool val = false;
if (strcmp(str_val, "1") == 0 || strcmp(str_val, "true") == 0 ||
strcmp(str_val, "True") == 0 || strcmp(str_val, "TRUE") == 0)
val = true;
return val;
}
}
} // namespace ipu
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <popart/ndarraywrapper.hpp>
#include <popart/tensordata.hpp>
#include <popart/tensorinfo.hpp>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace platform {
namespace ipu {
// onnx dtype
// https://github.com/onnx/onnx/blob/master/onnx/onnx-ml.proto3
enum ONNXDataType : int {
UNDEFINED = 0,
FLOAT = 1,
UINT8 = 2,
INT8 = 3,
UINT16 = 4,
INT16 = 5,
INT32 = 6,
INT64 = 7,
STRING = 8,
BOOL = 9,
FLOAT16 = 10,
DOUBLE = 11,
UINT32 = 12,
UINT64 = 13,
COMPLEX64 = 14,
COMPLEX128 = 15,
BFLOAT16 = 16
};
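// Adapter exposing a framework::Tensor through PopART's popart::IArray
// interface (raw data pointer, dtype, rank, per-dim sizes, element count).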
class PaddleIArray final : public popart::IArray {
public:
explicit PaddleIArray(framework::Tensor *tensor) : tensor_(tensor) {
for (int i = 0; i < tensor->dims().size(); ++i) {
shape_.push_back(tensor->dims().at(i));
}
}
public:
void *data();
popart::DataType dataType() const;
std::size_t rank() const;
int64_t dim(size_t index) const;
std::size_t nelms() const;
const popart::Shape shape() const;
private:
framework::Tensor *tensor_;
std::vector<int64_t> shape_;
};
popart::DataType VarType2PopartType(const framework::proto::VarType::Type type);
framework::proto::VarType::Type PopartType2VarType(const popart::DataType type);
popart::DataType OnnxDtype2PopartType(const int type);
bool GetBoolEnv(std::string str);
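// Wraps the tensor's existing buffer in a popart::NDArrayWrapper<T>; no data
// is copied, so the tensor must outlive the returned wrapper.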
template <typename T>
std::unique_ptr<popart::NDArrayWrapper<T>> Tensor2IArray(
const framework::Tensor &tensor) {
auto dtype = VarType2PopartType(tensor.type());
auto shape = std::vector<int64_t>();
for (int i = 0; i < tensor.dims().size(); ++i) {
shape.push_back(tensor.dims().at(i));
}
popart::TensorInfo tensor_info(dtype, shape);
return std::make_unique<popart::NDArrayWrapper<T>>(
reinterpret_cast<T *>(tensor.data<void>()), tensor_info);
}
template <typename T>
std::unique_ptr<popart::NDArrayWrapper<T>> LoDTensor2IArray(
framework::LoDTensor const &lod_tensor) {
PADDLE_ENFORCE_EQ(
lod_tensor.lod().size(), 0UL,
platform::errors::InvalidArgument(
"LoDTensor2IArray only supports tensors with an empty LoD; "
"converting a LoDTensor with LoD information is not implemented."));
return Tensor2IArray<T>(lod_tensor);
}
} // namespace ipu
} // namespace platform
} // namespace paddle
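Together, PaddleIArray and the *2IArray templates let PopART read and write Paddle buffers in place. A hedged sketch of the intended flow is below (includes omitted); the tensor IDs "x"/"y" and the function name are illustrative, while popart::StepIO and popart::Session::run are PopART's public API.

// Illustrative flow only, not code from this patch.
void RunStepExample(const paddle::framework::LoDTensor& input,
paddle::framework::Tensor* output,
popart::Session* session) {
namespace ipu = paddle::platform::ipu;
// Zero-copy wrap of the input; requires an empty LoD (enforced above).
auto in_array = ipu::LoDTensor2IArray<float>(input);
// Let PopART write the anchor result straight into a Paddle tensor.
ipu::PaddleIArray out_array(output);
std::map<popart::TensorId, popart::IArray&> inputs = {{"x", *in_array}};
std::map<popart::TensorId, popart::IArray&> anchors = {{"y", out_array}};
popart::StepIO stepio(inputs, anchors);
session->run(stepio);
}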
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// clang-format off
#pragma once
// Ops from AiGraphcoreOpset1
OP_DECL(popart_groupnormalization_v2, aiGraphcoreOpset.groupnormalization, ARG(INT,num_groups) ARG(FLOAT,epsilon) ) // NOLINT
OP_DECL(popart_subsample_v2, aiGraphcoreOpset.subsample, ARG(INT_VEC,strides) ) // NOLINT
OP_DECL(popart_nop_v2, aiGraphcoreOpset.nop, NONE) // NOLINT
OP_DECL(popart_scale_v2, aiGraphcoreOpset.scale, ARG(FLOAT,scale) ) // NOLINT
OP_DECL(popart_scaledadd_v2, aiGraphcoreOpset.scaledadd, ARG(FLOAT,scale0) ARG(FLOAT,scale1) ) // NOLINT
OP_DECL(popart_gelu_v2, aiGraphcoreOpset.gelu, NONE) // NOLINT
OP_DECL(popart_detach_v2, aiGraphcoreOpset.detach, NONE) // NOLINT
OP_DECL(popart_depthtospace_v2, aiGraphcoreOpset.depthtospace, ARG(INT,blocksize) ARG(STRING,mode) ) // NOLINT
OP_DECL(popart_round_v2, aiGraphcoreOpset.round, NONE) // NOLINT
OP_DECL(popart_dynamicslice_v2, aiGraphcoreOpset.dynamicslice, ARG(INT_VEC,axes) ARG(INT_VEC,sizes) ARG(INT,noOverlap) ) // NOLINT
OP_DECL(popart_dynamicupdate_v2, aiGraphcoreOpset.dynamicupdate, ARG(INT_VEC,axes) ARG(INT_VEC,sizes) ARG(INT,noOverlap) ) // NOLINT
OP_DECL(popart_dynamiczero_v2, aiGraphcoreOpset.dynamiczero, ARG(INT_VEC,axes) ARG(INT_VEC,sizes) ) // NOLINT
OP_DECL(popart_dynamicadd_v2, aiGraphcoreOpset.dynamicadd, ARG(INT_VEC,axes) ARG(INT_VEC,sizes) ) // NOLINT
OP_DECL(popart_sequenceslice_v2, aiGraphcoreOpset.sequenceslice, ARG(INT,zeroUnused) ) // NOLINT
OP_DECL(popart_replicatedallreduce_v2, aiGraphcoreOpset.replicatedallreduce, OPT_ARG(INT_VEC,commGroup) ) // NOLINT
OP_DECL(popart_ctcbeamsearchdecoder_v2, aiGraphcoreOpset.ctcbeamsearchdecoder, ARG(INT,blank) ARG(INT,beamWidth) ARG(INT,topPaths) ) // NOLINT
OP_DECL(popart_shapeddropout_v2, aiGraphcoreOpset.shapeddropout, ARG(INT_VEC,shape) ARG(FLOAT,ratio) ) // NOLINT
OP_DECL(popart_atan2_v2, aiGraphcoreOpset.atan2, NONE) // NOLINT
OP_DECL(popart_expm1_v2, aiGraphcoreOpset.expm1, NONE) // NOLINT
OP_DECL(popart_log1p_v2, aiGraphcoreOpset.log1p, NONE) // NOLINT
OP_DECL(popart_fmod_v2, aiGraphcoreOpset.fmod, NONE) // NOLINT
OP_DECL(popart_remainder_v2, aiGraphcoreOpset.remainder, NONE) // NOLINT
OP_DECL(popart_reverse_v2, aiGraphcoreOpset.reverse, ARG(INT_VEC,dimensions) ) // NOLINT
OP_DECL(popart_bitwisenot_v2, aiGraphcoreOpset.bitwisenot, NONE) // NOLINT
OP_DECL(popart_bitwiseand_v2, aiGraphcoreOpset.bitwiseand, NONE) // NOLINT
OP_DECL(popart_bitwiseor_v2, aiGraphcoreOpset.bitwiseor, NONE) // NOLINT
OP_DECL(popart_bitwisexor_v2, aiGraphcoreOpset.bitwisexor, NONE) // NOLINT
OP_DECL(popart_bitwisexnor_v2, aiGraphcoreOpset.bitwisexnor, NONE) // NOLINT
OP_DECL(popart_reducemedian_v2, aiGraphcoreOpset.reducemedian, OPT_ARG(INT_VEC,axes) ARG(INT,keepdims) ) // NOLINT
// Ops from AiOnnxOpset11
OP_DECL(popart_argmax, aiOnnxOpset.argmax, ARG(INT,axis) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_argmin, aiOnnxOpset.argmin, ARG(INT,axis) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_averagepool, aiOnnxOpset.averagepool, ARG(INT_VEC,kernel_shape) ARG(INT,ceil_mode) ARG(INT,count_include_pad) ARG(INT_VEC,pads) ARG(INT_VEC,strides) ) // NOLINT
OP_DECL(popart_bitshift, aiOnnxOpset.bitshift, ARG(STRING,direction) ) // NOLINT
OP_DECL(popart_clip, aiOnnxOpset.clip, NONE) // NOLINT
OP_DECL(popart_compress, aiOnnxOpset.compress, OPT_ARG(INT,axis) ) // NOLINT
OP_DECL(popart_concat, aiOnnxOpset.concat, ARG(INT,axis) ) // NOLINT
OP_DECL(popart_concatfromsequence, aiOnnxOpset.concatfromsequence, ARG(INT,axis) ARG(INT,new_axis) ) // NOLINT
OP_DECL(popart_conv, aiOnnxOpset.conv, ARG(INT_VEC,dilations) ARG(INT,group) ARG(INT_VEC,kernel_shape) ARG(INT_VEC,pads) ARG(INT_VEC,strides) ) // NOLINT
OP_DECL(popart_convtranspose, aiOnnxOpset.convtranspose, ARG(INT_VEC,dilations) ARG(INT,group) ARG(INT_VEC,kernel_shape) ARG(INT_VEC,output_padding) ARG(INT_VEC,output_shape) ARG(INT_VEC,pads) ARG(INT_VEC,strides) ) // NOLINT
OP_DECL(popart_cumsum, aiOnnxOpset.cumsum, ARG(INT,exclusive) ARG(INT,reverse) ) // NOLINT
OP_DECL(popart_depthtospace, aiOnnxOpset.depthtospace, ARG(INT,blocksize) ARG(STRING,mode) ) // NOLINT
OP_DECL(popart_det, aiOnnxOpset.det, NONE) // NOLINT
OP_DECL(popart_dynamicquantizelinear, aiOnnxOpset.dynamicquantizelinear, NONE) // NOLINT
OP_DECL(popart_equal, aiOnnxOpset.equal, NONE) // NOLINT
OP_DECL(popart_flatten, aiOnnxOpset.flatten, ARG(INT,axis) ) // NOLINT
OP_DECL(popart_gather, aiOnnxOpset.gather, ARG(INT,axis) ) // NOLINT
OP_DECL(popart_gatherelements, aiOnnxOpset.gatherelements, ARG(INT,axis) ) // NOLINT
OP_DECL(popart_gathernd, aiOnnxOpset.gathernd, NONE) // NOLINT
OP_DECL(popart_gemm, aiOnnxOpset.gemm, ARG(FLOAT,alpha) ARG(FLOAT,beta) ARG(INT,transA) ARG(INT,transB) ) // NOLINT
OP_DECL(popart_hardmax, aiOnnxOpset.hardmax, ARG(INT,axis) ) // NOLINT
OP_DECL(popart_logsoftmax, aiOnnxOpset.logsoftmax, ARG(INT,axis) ) // NOLINT
OP_DECL(popart_lppool, aiOnnxOpset.lppool, ARG(INT_VEC,kernel_shape) ARG(INT,p) ARG(INT_VEC,pads) ARG(INT_VEC,strides) ) // NOLINT
OP_DECL(popart_maxpool, aiOnnxOpset.maxpool, ARG(INT,num_outputs) ARG(INT_VEC,kernel_shape) ARG(INT,ceil_mode) ARG(INT_VEC,dilations) ARG(INT_VEC,pads) ARG(INT,storage_order) ARG(INT_VEC,strides) ) // NOLINT
OP_DECL(popart_maxunpool, aiOnnxOpset.maxunpool, ARG(INT_VEC,kernel_shape) ARG(INT_VEC,pads) ARG(INT_VEC,strides) ) // NOLINT
OP_DECL(popart_nonmaxsuppression, aiOnnxOpset.nonmaxsuppression, ARG(INT,center_point_box) ) // NOLINT
OP_DECL(popart_onehot, aiOnnxOpset.onehot, ARG(INT,axis) ) // NOLINT
OP_DECL(popart_pad, aiOnnxOpset.pad, ARG(STRING,mode) ) // NOLINT
OP_DECL(popart_range, aiOnnxOpset.range, NONE) // NOLINT
OP_DECL(popart_reducel1, aiOnnxOpset.reducel1, OPT_ARG(INT_VEC,axes) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_reducel2, aiOnnxOpset.reducel2, OPT_ARG(INT_VEC,axes) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_reducelogsum, aiOnnxOpset.reducelogsum, OPT_ARG(INT_VEC,axes) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_reducelogsumexp, aiOnnxOpset.reducelogsumexp, OPT_ARG(INT_VEC,axes) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_reducemax, aiOnnxOpset.reducemax, OPT_ARG(INT_VEC,axes) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_reducemean, aiOnnxOpset.reducemean, OPT_ARG(INT_VEC,axes) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_reducemin, aiOnnxOpset.reducemin, OPT_ARG(INT_VEC,axes) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_reduceprod, aiOnnxOpset.reduceprod, OPT_ARG(INT_VEC,axes) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_reducesum, aiOnnxOpset.reducesum, OPT_ARG(INT_VEC,axes) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_reducesumsquare, aiOnnxOpset.reducesumsquare, OPT_ARG(INT_VEC,axes) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_resize, aiOnnxOpset.resize, ARG(STRING,coordinate_transformation_mode) ARG(FLOAT,cubic_coeff_a) ARG(INT,exclude_outside) ARG(FLOAT,extrapolation_value) ARG(STRING,mode) ARG(STRING,nearest_mode) ) // NOLINT
OP_DECL(popart_round, aiOnnxOpset.round, NONE) // NOLINT
OP_DECL(popart_scatter, aiOnnxOpset.scatter, ARG(INT,axis) ) // NOLINT
OP_DECL(popart_scatterelements, aiOnnxOpset.scatterelements, ARG(INT,axis) ) // NOLINT
OP_DECL(popart_scatternd, aiOnnxOpset.scatternd, NONE) // NOLINT
OP_DECL(popart_sequenceat, aiOnnxOpset.sequenceat, NONE) // NOLINT
OP_DECL(popart_sequenceconstruct, aiOnnxOpset.sequenceconstruct, NONE) // NOLINT
OP_DECL(popart_sequenceerase, aiOnnxOpset.sequenceerase, NONE) // NOLINT
OP_DECL(popart_sequenceinsert, aiOnnxOpset.sequenceinsert, NONE) // NOLINT
OP_DECL(popart_sequencelength, aiOnnxOpset.sequencelength, NONE) // NOLINT
OP_DECL(popart_slice, aiOnnxOpset.slice, NONE) // NOLINT
OP_DECL(popart_softmax, aiOnnxOpset.softmax, ARG(INT,axis) ) // NOLINT
OP_DECL(popart_split, aiOnnxOpset.split, ARG(INT,num_outputs) ARG(INT,axis) ARG(INT_VEC,split) ) // NOLINT
OP_DECL(popart_splittosequence, aiOnnxOpset.splittosequence, ARG(INT,axis) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_squeeze, aiOnnxOpset.squeeze, ARG(INT_VEC,axes) ) // NOLINT
OP_DECL(popart_topk, aiOnnxOpset.topk, ARG(INT,axis) ARG(INT,largest) ARG(INT,sorted) ) // NOLINT
OP_DECL(popart_unique, aiOnnxOpset.unique, ARG(INT,num_outputs) OPT_ARG(INT,axis) ARG(INT,sorted) ) // NOLINT
OP_DECL(popart_unsqueeze, aiOnnxOpset.unsqueeze, ARG(INT_VEC,axes) ) // NOLINT
// Ops from AiOnnxOpset10
OP_DECL(popart_convinteger, aiOnnxOpset.convinteger, ARG(INT_VEC,dilations) ARG(INT,group) ARG(INT_VEC,kernel_shape) ARG(INT_VEC,pads) ARG(INT_VEC,strides) ) // NOLINT
OP_DECL(popart_dequantizelinear, aiOnnxOpset.dequantizelinear, NONE) // NOLINT
OP_DECL(popart_dropout, aiOnnxOpset.dropout, ARG(INT,num_outputs) ARG(FLOAT,ratio) ) // NOLINT
OP_DECL(popart_isinf, aiOnnxOpset.isinf, ARG(INT,detect_negative) ARG(INT,detect_positive) ) // NOLINT
OP_DECL(popart_matmulinteger, aiOnnxOpset.matmulinteger, NONE) // NOLINT
OP_DECL(popart_mod, aiOnnxOpset.mod, ARG(INT,fmod) ) // NOLINT
OP_DECL(popart_qlinearconv, aiOnnxOpset.qlinearconv, ARG(INT_VEC,dilations) ARG(INT,group) ARG(INT_VEC,kernel_shape) ARG(INT_VEC,pads) ARG(INT_VEC,strides) ) // NOLINT
OP_DECL(popart_qlinearmatmul, aiOnnxOpset.qlinearmatmul, NONE) // NOLINT
OP_DECL(popart_quantizelinear, aiOnnxOpset.quantizelinear, NONE) // NOLINT
OP_DECL(popart_reversesequence, aiOnnxOpset.reversesequence, ARG(INT,batch_axis) ARG(INT,time_axis) ) // NOLINT
OP_DECL(popart_roialign, aiOnnxOpset.roialign, ARG(STRING,mode) ARG(INT,output_height) ARG(INT,output_width) ARG(INT,sampling_ratio) ARG(FLOAT,spatial_scale) ) // NOLINT
OP_DECL(popart_thresholdedrelu, aiOnnxOpset.thresholdedrelu, ARG(FLOAT,alpha) ) // NOLINT
OP_DECL(popart_upsample, aiOnnxOpset.upsample, ARG(STRING,mode) ) // NOLINT
// Ops from AiOnnxOpset9
OP_DECL(popart_acosh, aiOnnxOpset.acosh, NONE) // NOLINT
OP_DECL(popart_asinh, aiOnnxOpset.asinh, NONE) // NOLINT
OP_DECL(popart_atanh, aiOnnxOpset.atanh, NONE) // NOLINT
OP_DECL(popart_batchnormalization, aiOnnxOpset.batchnormalization, ARG(INT,num_outputs) ARG(FLOAT,epsilon) ARG(FLOAT,momentum) ) // NOLINT
OP_DECL(popart_cast, aiOnnxOpset.cast, ARG(STRING,to) ) // NOLINT
OP_DECL(popart_cosh, aiOnnxOpset.cosh, NONE) // NOLINT
OP_DECL(popart_erf, aiOnnxOpset.erf, NONE) // NOLINT
OP_DECL(popart_eyelike, aiOnnxOpset.eyelike, OPT_ARG(INT,dtype) ARG(INT,k) ) // NOLINT
OP_DECL(popart_greater, aiOnnxOpset.greater, NONE) // NOLINT
OP_DECL(popart_isnan, aiOnnxOpset.isnan, NONE) // NOLINT
OP_DECL(popart_less, aiOnnxOpset.less, NONE) // NOLINT
OP_DECL(popart_matmul, aiOnnxOpset.matmul, NONE) // NOLINT
OP_DECL(popart_meanvariancenormalization, aiOnnxOpset.meanvariancenormalization, ARG(INT_VEC,axes) ) // NOLINT
OP_DECL(popart_nonzero, aiOnnxOpset.nonzero, NONE) // NOLINT
OP_DECL(popart_prelu, aiOnnxOpset.prelu, NONE) // NOLINT
OP_DECL(popart_shrink, aiOnnxOpset.shrink, ARG(FLOAT,bias) ARG(FLOAT,lambd) ) // NOLINT
OP_DECL(popart_sign, aiOnnxOpset.sign, NONE) // NOLINT
OP_DECL(popart_sinh, aiOnnxOpset.sinh, NONE) // NOLINT
OP_DECL(popart_where, aiOnnxOpset.where, NONE) // NOLINT
// Ops from AiOnnxOpset8
OP_DECL(popart_expand, aiOnnxOpset.expand, NONE) // NOLINT
OP_DECL(popart_max, aiOnnxOpset.max, NONE) // NOLINT
OP_DECL(popart_mean, aiOnnxOpset.mean, NONE) // NOLINT
OP_DECL(popart_min, aiOnnxOpset.min, NONE) // NOLINT
OP_DECL(popart_sum, aiOnnxOpset.sum, NONE) // NOLINT
// Ops from AiOnnxOpset7
OP_DECL(popart_acos, aiOnnxOpset.acos, NONE) // NOLINT
OP_DECL(popart_add, aiOnnxOpset.add, NONE) // NOLINT
OP_DECL(popart_logical_and, aiOnnxOpset.logical_and, NONE) // NOLINT
OP_DECL(popart_asin, aiOnnxOpset.asin, NONE) // NOLINT
OP_DECL(popart_atan, aiOnnxOpset.atan, NONE) // NOLINT
OP_DECL(popart_cos, aiOnnxOpset.cos, NONE) // NOLINT
OP_DECL(popart_div, aiOnnxOpset.div, NONE) // NOLINT
OP_DECL(popart_mul, aiOnnxOpset.mul, NONE) // NOLINT
OP_DECL(popart_multinomial, aiOnnxOpset.multinomial, ARG(INT,dtype) ARG(INT,sample_size) OPT_ARG(FLOAT,seed) ) // NOLINT
OP_DECL(popart_logical_or, aiOnnxOpset.logical_or, NONE) // NOLINT
OP_DECL(popart_pow, aiOnnxOpset.pow, NONE) // NOLINT
OP_DECL(popart_sin, aiOnnxOpset.sin, NONE) // NOLINT
OP_DECL(popart_sub, aiOnnxOpset.sub, NONE) // NOLINT
OP_DECL(popart_tan, aiOnnxOpset.tan, NONE) // NOLINT
OP_DECL(popart_logical_xor, aiOnnxOpset.logical_xor, NONE) // NOLINT
// Ops from AiOnnxOpset6
OP_DECL(popart_abs, aiOnnxOpset.abs, NONE) // NOLINT
OP_DECL(popart_ceil, aiOnnxOpset.ceil, NONE) // NOLINT
OP_DECL(popart_elu, aiOnnxOpset.elu, ARG(FLOAT,alpha) ) // NOLINT
OP_DECL(popart_exp, aiOnnxOpset.exp, NONE) // NOLINT
OP_DECL(popart_floor, aiOnnxOpset.floor, NONE) // NOLINT
OP_DECL(popart_globalaveragepool, aiOnnxOpset.globalaveragepool, NONE) // NOLINT
OP_DECL(popart_globallppool, aiOnnxOpset.globallppool, ARG(INT,p) ) // NOLINT
OP_DECL(popart_globalmaxpool, aiOnnxOpset.globalmaxpool, NONE) // NOLINT
OP_DECL(popart_hardsigmoid, aiOnnxOpset.hardsigmoid, ARG(FLOAT,alpha) ARG(FLOAT,beta) ) // NOLINT
OP_DECL(popart_identity, aiOnnxOpset.identity, NONE) // NOLINT
OP_DECL(popart_instancenormalization, aiOnnxOpset.instancenormalization, ARG(FLOAT,epsilon) ) // NOLINT
OP_DECL(popart_lrn, aiOnnxOpset.lrn, ARG(INT,size) ARG(FLOAT,alpha) ARG(FLOAT,beta) ARG(FLOAT,bias) ) // NOLINT
OP_DECL(popart_leakyrelu, aiOnnxOpset.leakyrelu, ARG(FLOAT,alpha) ) // NOLINT
OP_DECL(popart_log, aiOnnxOpset.log, NONE) // NOLINT
OP_DECL(popart_lpnormalization, aiOnnxOpset.lpnormalization, ARG(INT,axis) ARG(INT,p) ) // NOLINT
OP_DECL(popart_maxroipool, aiOnnxOpset.maxroipool, ARG(INT_VEC,pooled_shape) ARG(FLOAT,spatial_scale) ) // NOLINT
OP_DECL(popart_neg, aiOnnxOpset.neg, NONE) // NOLINT
OP_DECL(popart_logical_not, aiOnnxOpset.logical_not, NONE) // NOLINT
OP_DECL(popart_randomnormallike, aiOnnxOpset.randomnormallike, OPT_ARG(INT,dtype) ARG(FLOAT,mean) ARG(FLOAT,scale) OPT_ARG(FLOAT,seed) ) // NOLINT
OP_DECL(popart_randomuniformlike, aiOnnxOpset.randomuniformlike, OPT_ARG(INT,dtype) ARG(FLOAT,high) ARG(FLOAT,low) OPT_ARG(FLOAT,seed) ) // NOLINT
OP_DECL(popart_reciprocal, aiOnnxOpset.reciprocal, NONE) // NOLINT
OP_DECL(popart_relu, aiOnnxOpset.relu, NONE) // NOLINT
OP_DECL(popart_reshape, aiOnnxOpset.reshape, NONE) // NOLINT
OP_DECL(popart_selu, aiOnnxOpset.selu, ARG(FLOAT,alpha) ARG(FLOAT,gamma) ) // NOLINT
OP_DECL(popart_shape, aiOnnxOpset.shape, NONE) // NOLINT
OP_DECL(popart_sigmoid, aiOnnxOpset.sigmoid, NONE) // NOLINT
OP_DECL(popart_size, aiOnnxOpset.size, NONE) // NOLINT
OP_DECL(popart_softplus, aiOnnxOpset.softplus, NONE) // NOLINT
OP_DECL(popart_softsign, aiOnnxOpset.softsign, NONE) // NOLINT
OP_DECL(popart_spacetodepth, aiOnnxOpset.spacetodepth, ARG(INT,blocksize) ) // NOLINT
OP_DECL(popart_sqrt, aiOnnxOpset.sqrt, NONE) // NOLINT
OP_DECL(popart_tanh, aiOnnxOpset.tanh, NONE) // NOLINT
OP_DECL(popart_tile, aiOnnxOpset.tile, NONE) // NOLINT
OP_DECL(popart_transpose, aiOnnxOpset.transpose, ARG(INT_VEC,perm) ) // NOLINT
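These OP_DECL rows form an X-macro table: a consumer defines OP_DECL (plus the ARG/OPT_ARG/NONE argument markers, if its expansion uses them) to whatever expansion it needs, includes this header, and the whole op list expands in place. A minimal sketch of that pattern, assuming this file is named supported_ops_autogen.h:

#include <iostream>

// Hypothetical consumer: print the name of every declared PopART op.
#define ARG(type, name)
#define OPT_ARG(type, name)
#define NONE
void PrintSupportedOps() {
#define OP_DECL(op_name, handler, args) std::cout << #op_name << "\n";
#include "supported_ops_autogen.h"  // the table above; file name assumed
#undef OP_DECL
}
#undef NONE
#undef OPT_ARG
#undef ARG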