Unverified commit 1b5647d7, authored by Allen Guo and committed by GitHub

[IPU] clean ipu related code (#42511)

* clean code

* fix ci

* fix ci

* fix ci 2
Parent 6ff35e17
@@ -13,12 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/platform/device/ipu/ipu_backend.h" #include "paddle/fluid/platform/device/ipu/ipu_backend.h"
#include "paddle/fluid/platform/device/ipu/ipu_utils.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/platform/device/ipu/ipu_compiler.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/platform/device/ipu/ipu_executor.h"
#include "paddle/fluid/framework/ir/node.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
@@ -40,7 +38,7 @@ IpuBackend::~IpuBackend() {
executor_.reset(); executor_.reset();
} }
void IpuBackend::Compile(Graph* graph, void IpuBackend::Compile(framework::ir::Graph* graph,
const std::vector<std::string>& feed_list, const std::vector<std::string>& feed_list,
const std::vector<std::string>& fetch_list) { const std::vector<std::string>& fetch_list) {
VLOG(10) << "enter IpuBackend::Compile"; VLOG(10) << "enter IpuBackend::Compile";
@@ -63,8 +61,8 @@ void IpuBackend::Compile(Graph* graph,
VLOG(10) << "leave IpuBackend::Compile"; VLOG(10) << "leave IpuBackend::Compile";
} }
void IpuBackend::Run(const std::vector<const Tensor*>& inputs, void IpuBackend::Run(const std::vector<const framework::Tensor*>& inputs,
const std::vector<Tensor*>& outputs, const std::vector<framework::Tensor*>& outputs,
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
timer_->Start(); timer_->Start();
executor_->Run(inputs, outputs, ctx); executor_->Run(inputs, outputs, ctx);
@@ -82,7 +80,7 @@ void IpuBackend::Reset() {
executor_.reset(); executor_.reset();
} }
void IpuBackend::SetScope(const Scope& scope) { void IpuBackend::SetScope(const framework::Scope& scope) {
scope_ = &scope; scope_ = &scope;
executor_->SetScope(&scope); executor_->SetScope(&scope);
} }
......
@@ -18,26 +18,25 @@ limitations under the License. */
#include <popart/names.hpp> #include <popart/names.hpp>
#include <popart/tensorinfo.hpp> #include <popart/tensorinfo.hpp>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device/ipu/ipu_compiler.h"
#include "paddle/fluid/platform/device/ipu/ipu_device.h"
#include "paddle/fluid/platform/device/ipu/ipu_executor.h"
#include "paddle/fluid/platform/device/ipu/ipu_strategy.h" #include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
#include "paddle/fluid/platform/device/ipu/ipu_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
} // namespace paddle
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace ipu { namespace ipu {
// IpuBackend is the center of paddle-ipu, its function include: class IpuStrategy;
// 1. Compile paddle model to popart model class Compiler;
// 2. Run popart model, inference or training class Executor;
// 3. Request and release device
// 4. Other helper function
class IpuBackend { class IpuBackend {
public: public:
static IpuBackend *GetInstance(); static IpuBackend *GetInstance();
@@ -46,47 +45,46 @@ class IpuBackend {
IpuBackend(); IpuBackend();
~IpuBackend(); ~IpuBackend();
// what compile does include(call compiler_): // What compile method does:
// 1. map paddle-op -> poart op // Convert paddle ops to popart ops;
// 2. construct popart onnx compute graph // Construct a popart graph, which is an onnx compute graph;
void Compile(Graph *graph, const std::vector<std::string> &feed_list, // Load the graph and weights to ipu.
void Compile(framework::ir::Graph *graph,
const std::vector<std::string> &feed_list,
const std::vector<std::string> &fetch_list); const std::vector<std::string> &fetch_list);
// what run does include: // Run the compiled graph on ipu
// 1. construct forward onnx graph void Run(const std::vector<const framework::Tensor *> &inputs,
// 2. graph-level optimization const std::vector<framework::Tensor *> &outputs,
// 3. autodiff
void Run(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs,
const framework::ExecutionContext &ctx); const framework::ExecutionContext &ctx);
// Sync weights from IPU while training // Sync weights from IPU while training
void WeightsToHost(); void WeightsToHost();
// detach IPU manually // Detach IPU manually
void Detach(); void Detach();
// reset manually // Reset manually
// call it before destruct works // Call it before the destructor runs
void Reset(); void Reset();
void SetScope(const Scope &scope); void SetScope(const framework::Scope &scope);
const Scope *GetScope() { return scope_; } const framework::Scope *GetScope() { return scope_; }
void SetIpuStrategy(const IpuStrategy &strategy); void SetIpuStrategy(const IpuStrategy &strategy);
const IpuStrategy *GetIpuStrategy() { return ipu_strategy_; } const IpuStrategy *GetIpuStrategy() { return ipu_strategy_; }
// save compiled model to onnx // Save compiled model to onnx
void SaveModelProto(const std::string &path); void SaveModelProto(const std::string &path);
private: private:
// not own // Not own
const Scope *scope_ = nullptr; const framework::Scope *scope_ = nullptr;
const IpuStrategy *ipu_strategy_ = nullptr; const IpuStrategy *ipu_strategy_ = nullptr;
// own // Own
std::unique_ptr<Compiler> compiler_; std::unique_ptr<Compiler> compiler_;
std::unique_ptr<Executor> executor_; std::unique_ptr<Executor> executor_;
std::unique_ptr<platform::Timer> timer_; std::unique_ptr<Timer> timer_;
bool is_compiled_ = false; bool is_compiled_ = false;
......
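For context, the public surface of IpuBackend above is small enough to sketch end to end. The snippet below is a minimal, illustrative usage sketch only; the real call sites live elsewhere in the Paddle framework (IPU-specific passes and ops), and the scope, strategy, graph, feed/fetch lists and tensors are assumed to be prepared by the caller.

// Minimal sketch (not the actual framework call site):
using paddle::platform::ipu::IpuBackend;

auto* backend = IpuBackend::GetInstance();
backend->SetScope(scope);                        // framework::Scope owning the weights (not owned by the backend)
backend->SetIpuStrategy(ipu_strategy);           // num_ipus, fp16 options, etc.
backend->Compile(graph, feed_list, fetch_list);  // lower framework::ir::Graph to a popart/onnx graph
backend->Run(inputs, outputs, ctx);              // execute on the attached IPU(s)
backend->WeightsToHost();                        // sync trained weights back to the scope
backend->SaveModelProto("model.onnx");           // optional: dump the compiled model (file name is arbitrary)
backend->Detach();                               // release the device explicitly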
@@ -20,12 +20,110 @@
#include <popart/sgd.hpp> #include <popart/sgd.hpp>
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/device/ipu/ipu_names.h"
#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
#include "paddle/fluid/platform/device/ipu/ipu_utils.h" #include "paddle/fluid/platform/device/ipu/ipu_utils.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace ipu { namespace ipu {
namespace {
struct CustomOpAttrVisitor : public boost::static_visitor<void> {
CustomOpAttrVisitor(std::map<std::string, popart::any>* attr,
const std::string& attr_name)
: attrs_(attr), attr_name_(attr_name) {}
mutable std::map<std::string, popart::any>* attrs_;
std::string attr_name_;
void operator()(int v) const { attrs_->emplace(attr_name_, v); }
void operator()(float v) const { attrs_->emplace(attr_name_, v); }
void operator()(const std::string& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(const std::vector<int>& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(const std::vector<float>& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(const std::vector<std::string>& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(bool v) const { attrs_->emplace(attr_name_, v); }
void operator()(const std::vector<bool>& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(BlockDesc* desc) const {
PADDLE_THROW(platform::errors::Unavailable(
"Unsupported calling method for `BlockDesc` type when extracting "
"custom operator attributes."));
}
void operator()(const std::vector<BlockDesc*>& v) const {
PADDLE_THROW(platform::errors::Unavailable(
"Unsupported calling method for `BlockDesc` type when extracting "
"custom operator attributes."));
}
void operator()(int64_t v) const { attrs_->emplace(attr_name_, v); }
void operator()(const std::vector<int64_t>& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(const std::vector<double>& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(boost::blank) const {
PADDLE_THROW(platform::errors::Unavailable(
"Unsupported calling method for `boost::blank` type when extracting "
"custom operator attributes."));
}
};
struct ConstantOpAttrVisitor : public boost::static_visitor<void> {
ConstantOpAttrVisitor(framework::LoDTensor* tensor, VarType::Type dtype)
: tensor_(tensor), dtype_(dtype) {}
framework::LoDTensor* tensor_;
VarType::Type dtype_;
void operator()(const std::vector<int>& vec) const {
framework::TensorFromVector<int>(vec, tensor_);
}
void operator()(const std::vector<float>& vec) const {
if (dtype_ == VarType::FP16) {
std::vector<float16> vec_fp16;
std::transform(vec.begin(), vec.end(), std::back_inserter(vec_fp16),
[](float f) -> float16 { return float16(f); });
framework::TensorFromVector<float16>(vec_fp16, tensor_);
} else {
framework::TensorFromVector<float>(vec, tensor_);
}
}
void operator()(const std::vector<bool>& vec) const {
framework::TensorFromVector<bool>(vec, tensor_);
}
void operator()(const std::vector<int64_t>& vec) const {
framework::TensorFromVector<int64_t>(vec, tensor_);
}
void operator()(const std::vector<double>& vec) const {
framework::TensorFromVector<double>(vec, tensor_);
}
#define RAISE_ERROR \
PADDLE_THROW( \
platform::errors::InvalidArgument("Constant value must be a vector"))
void operator()(int v) const { RAISE_ERROR; }
void operator()(float v) const { RAISE_ERROR; }
void operator()(const std::string& v) const { RAISE_ERROR; }
void operator()(const std::vector<std::string>& v) const { RAISE_ERROR; }
void operator()(bool v) const { RAISE_ERROR; }
void operator()(BlockDesc* desc) const { RAISE_ERROR; }
void operator()(const std::vector<BlockDesc*>& v) const { RAISE_ERROR; }
void operator()(int64_t v) const { RAISE_ERROR; }
void operator()(boost::blank) const { RAISE_ERROR; }
#undef RAISE_ERROR
};
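Both visitors follow boost's static-visitor pattern. As a rough, assumed usage sketch (the actual call sites are further down in ipu_compiler.cc and are not part of this hunk), an op's attribute map can be flattened into popart attributes like this; `op_desc` is a framework::OpDesc* and the attribute is assumed to be a boost::variant, as the boost::static_visitor base implies.

// Illustrative only:
std::map<std::string, popart::any> attributes;
for (auto& attr_name : op_desc->AttrNames()) {
  CustomOpAttrVisitor visitor(&attributes, attr_name);
  auto attr = op_desc->GetAttr(attr_name);
  boost::apply_visitor(visitor, attr);
}
// `attributes` can then be handed to builder_->customOp(...) when lowering
// a popart_custom_op node.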
popart::AdamMode AdamModeFromStr(const std::string& str, popart::AdamMode AdamModeFromStr(const std::string& str,
const bool& use_no_bias_optimizer) { const bool& use_no_bias_optimizer) {
if (str == "adam") { if (str == "adam") {
@@ -117,6 +215,34 @@ TO GetCastSigAttrAllowNull(std::string attr, OpDesc* op_desc) {
} }
} }
// Helper for adding namescope info
struct NameScopeHelper {
NameScopeHelper(const OpDesc* op, popart::Builder* builder);
~NameScopeHelper() {
if (pushed_) {
builder_->popNameScope();
}
}
bool pushed_ = false;
popart::Builder* builder_;
};
NameScopeHelper::NameScopeHelper(const OpDesc* op, popart::Builder* builder)
: builder_(builder) {
auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
if (op_namescope.empty() || op_namescope == "/") {
return;
}
op_namescope.pop_back();
op_namescope.erase(op_namescope.begin());
builder->pushNameScope(op_namescope);
pushed_ = true;
}
} // namespace
GraphHelper::GraphHelper(const Graph* g) { GraphHelper::GraphHelper(const Graph* g) {
graph = g; graph = g;
sorted_ops = framework::ir::TopologySortOperations(*g); sorted_ops = framework::ir::TopologySortOperations(*g);
@@ -181,14 +307,12 @@ void Compiler::RegisterOpFunc() {
auto op_type = op_desc->Type(); \ auto op_type = op_desc->Type(); \
VLOG(10) << "build op:" << op_type << " args " << #Args; \ VLOG(10) << "build op:" << op_type << " args " << #Args; \
auto inputs = GetOpInputs(op_desc); \ auto inputs = GetOpInputs(op_desc); \
auto output_names = GetOpOutputs(op_desc); \
auto debug_context = BuildDebugContext(op_desc); \ auto debug_context = BuildDebugContext(op_desc); \
auto aiGraphcoreOpset = builder_->aiGraphcoreOpset1(); \ auto aiGraphcoreOpset = builder_->aiGraphcoreOpset1(); \
auto aiOnnxOpset = builder_->aiOnnxOpset11(); \ auto aiOnnxOpset = builder_->aiOnnxOpset11(); \
NameScopeHelper ns_helper(op_desc, builder_.get()); \ NameScopeHelper ns_helper(op_desc, builder_.get()); \
auto output_ids = OnnxImpl(inputs Args, debug_context); \ auto output_ids = OnnxImpl(inputs Args, debug_context); \
PostLower(output_ids, op_desc); \ PostLower(output_ids, op_desc); \
InsertTensors(output_names, output_ids); \
}}, // NOLINT }}, // NOLINT
#include "paddle/fluid/platform/device/ipu/supported_ops_autogen.h" #include "paddle/fluid/platform/device/ipu/supported_ops_autogen.h"
#include "paddle/fluid/platform/device/ipu/supported_ops_custom.h" #include "paddle/fluid/platform/device/ipu/supported_ops_custom.h"
@@ -219,7 +343,7 @@ void Compiler::InitInputs(const std::vector<std::string>& feed_list) {
auto* node = graph_helper_->vars_name_map[feed_name]; auto* node = graph_helper_->vars_name_map[feed_name];
auto* var_desc = node->Var(); auto* var_desc = node->Var();
VLOG(10) << "feed_name= " << var_desc->Name(); VLOG(10) << "feed_name= " << var_desc->Name();
auto data_type = VarType2PopartType(var_desc->GetDataType()); auto data_type = VarType2PopartDType(var_desc->GetDataType());
popart::TensorInfo input_info{data_type, var_desc->GetShape()}; popart::TensorInfo input_info{data_type, var_desc->GetShape()};
VLOG(10) << "popart input_info = " << input_info; VLOG(10) << "popart input_info = " << input_info;
popart::TensorId tensor_id = popart::TensorId tensor_id =
@@ -255,8 +379,9 @@ void Compiler::LowerConstants(const Scope* scope) {
auto shape = auto shape =
BOOST_GET_CONST(std::vector<int64_t>, op_desc->GetAttr("dims")); BOOST_GET_CONST(std::vector<int64_t>, op_desc->GetAttr("dims"));
auto dtype_ = BOOST_GET_CONST(int, op_desc->GetAttr("dtype")); auto dtype_ = BOOST_GET_CONST(int, op_desc->GetAttr("dtype"));
auto dtype = PopartType2VarType(OnnxDtype2PopartType(dtype_)); auto dtype = PopartDType2VarType(
auto tensor_name = op_desc->Output("__outputs__")[0]; OnnxDType2PopartType(static_cast<ONNXDataType>(dtype_)));
auto tensor_name = GetOpOutputs(op_desc).front();
auto* var = kid_scope.Var(tensor_name); auto* var = kid_scope.Var(tensor_name);
VLOG(10) << "lowering constant: " << tensor_name; VLOG(10) << "lowering constant: " << tensor_name;
auto* tensor = var->GetMutable<framework::LoDTensor>(); auto* tensor = var->GetMutable<framework::LoDTensor>();
@@ -267,7 +392,7 @@ void Compiler::LowerConstants(const Scope* scope) {
tensor->Resize(ddim); tensor->Resize(ddim);
auto const_data = std::unique_ptr<popart::ConstVoidData>(); auto const_data = std::unique_ptr<popart::ConstVoidData>();
popart::TensorInfo tensor_info(PdDataType2PopartType(tensor->dtype()), popart::TensorInfo tensor_info(PhiDType2PopartDType(tensor->dtype()),
shape); shape);
const_data.reset(new popart::ConstVoidData(tensor->data(), tensor_info)); const_data.reset(new popart::ConstVoidData(tensor->data(), tensor_info));
NameScopeHelper ns_helper(op_desc, builder_.get()); NameScopeHelper ns_helper(op_desc, builder_.get());
@@ -303,7 +428,7 @@ void Compiler::LowerWeights(const Scope* scope) {
var, platform::errors::NotFound("Tensor %s is not found in the scope", var, platform::errors::NotFound("Tensor %s is not found in the scope",
var_name)); var_name));
auto tensor = var->Get<framework::LoDTensor>(); auto tensor = var->Get<framework::LoDTensor>();
auto dtype = PdDataType2PopartType(tensor.dtype()); auto dtype = PhiDType2PopartDType(tensor.dtype());
auto shape = std::vector<int64_t>(); auto shape = std::vector<int64_t>();
for (size_t i = 0; i < tensor.dims().size(); ++i) { for (size_t i = 0; i < tensor.dims().size(); ++i) {
shape.push_back(tensor.dims().at(i)); shape.push_back(tensor.dims().at(i));
@@ -336,11 +461,9 @@ void Compiler::LowerBody() {
// pass // pass
} else if (op_type == "popart_checkpointoutput") { } else if (op_type == "popart_checkpointoutput") {
auto inputs = GetOpInputs(op_desc); auto inputs = GetOpInputs(op_desc);
auto outputs = GetOpOutputs(op_desc);
NameScopeHelper ns_helper(op_desc, builder_.get()); NameScopeHelper ns_helper(op_desc, builder_.get());
auto output_ids = builder_->checkpointOutput(inputs); auto output_ids = builder_->checkpointOutput(inputs);
PostLower(output_ids, op_desc); PostLower(output_ids, op_desc);
InsertTensors(outputs, output_ids);
} else if (op_type == "popart_custom_op") { } else if (op_type == "popart_custom_op") {
auto inputs = GetOpInputs(op_desc); auto inputs = GetOpInputs(op_desc);
auto outputs = GetOpOutputs(op_desc); auto outputs = GetOpOutputs(op_desc);
@@ -359,10 +482,8 @@ void Compiler::LowerBody() {
builder_->customOp(it->second.popart_op, it->second.popart_op.version, builder_->customOp(it->second.popart_op, it->second.popart_op.version,
inputs, outputs.size(), attributes, debug_context); inputs, outputs.size(), attributes, debug_context);
PostLower(output_ids, op_desc); PostLower(output_ids, op_desc);
InsertTensors(outputs, output_ids);
} else if (op_type == "popart_printtensor") { } else if (op_type == "popart_printtensor") {
auto inputs = GetOpInputs(op_desc); auto inputs = GetOpInputs(op_desc);
auto outputs = GetOpOutputs(op_desc);
auto debug_context = BuildDebugContext(op_desc); auto debug_context = BuildDebugContext(op_desc);
auto print_gradient = auto print_gradient =
BOOST_GET_CONST(int64_t, op_desc->GetAttr("print_gradient")); BOOST_GET_CONST(int64_t, op_desc->GetAttr("print_gradient"));
@@ -371,7 +492,6 @@ void Compiler::LowerBody() {
auto output_ids = builder_->aiGraphcoreOpset1().printtensor( auto output_ids = builder_->aiGraphcoreOpset1().printtensor(
inputs, print_gradient, debug_context, title); inputs, print_gradient, debug_context, title);
PostLower(output_ids, op_desc); PostLower(output_ids, op_desc);
InsertTensors(outputs, output_ids);
} else { } else {
auto itr = name_function_.find(op_type); auto itr = name_function_.find(op_type);
if (itr != name_function_.end()) { if (itr != name_function_.end()) {
@@ -601,23 +721,6 @@ void Compiler::LowerOptimizer(const Scope* scope) {
} }
} }
void Compiler::InsertTensors(const std::vector<std::string>& output_names,
const std::vector<std::string>& tensor_ids) {
PADDLE_ENFORCE_EQ(output_names.size(), tensor_ids.size(),
platform::errors::Fatal("InsertTensors size mismatch"));
for (int i = 0; i < tensor_ids.size(); i++) {
std::string tensor_id = tensor_ids[i];
resources_->tensors.emplace(output_names[i], tensor_ids[i]);
}
}
void Compiler::InsertTensors(const std::vector<std::string>& output_names,
const std::string& tensor_id) {
PADDLE_ENFORCE_EQ(output_names.size(), 1,
platform::errors::Fatal("InsertTensors size mismatch"));
resources_->tensors.emplace(output_names[0], tensor_id);
}
void Compiler::PostLower(const std::vector<std::string>& tensor_ids, void Compiler::PostLower(const std::vector<std::string>& tensor_ids,
const OpDesc* op_desc) { const OpDesc* op_desc) {
// Set pipline // Set pipline
@@ -637,13 +740,26 @@ void Compiler::PostLower(const std::vector<std::string>& tensor_ids,
<< " for op: " << op_desc->Type(); << " for op: " << op_desc->Type();
} }
} }
// Record output tensors
auto pd_outs = GetOpOutputs(op_desc);
PADDLE_ENFORCE_EQ(
pd_outs.size(), tensor_ids.size(),
platform::errors::Fatal("paddle and popart op have different outputs"));
for (int i = 0; i < tensor_ids.size(); ++i) {
resources_->tensors.emplace(pd_outs[i], tensor_ids[i]);
}
for (auto& tensor_id : tensor_ids) { for (auto& tensor_id : tensor_ids) {
PostLower(tensor_id, op_desc, true); PostLower(tensor_id, op_desc, true);
} }
} }
void Compiler::PostLower(const std::string& tensor_id, const OpDesc* op_desc) { void Compiler::PostLower(const std::string& tensor_id, const OpDesc* op_desc) {
// Record output tensor
auto pd_outs = GetOpOutputs(op_desc);
PADDLE_ENFORCE_EQ(
pd_outs.size(), 1,
platform::errors::Fatal("paddle and popart op have different outputs"));
resources_->tensors.emplace(pd_outs[0], tensor_id);
PostLower(tensor_id, op_desc, false); PostLower(tensor_id, op_desc, false);
} }
@@ -718,13 +834,7 @@ std::string Compiler::GetFP16ModelProto() {
return graph_transformer.getModelProto(); return graph_transformer.getModelProto();
} }
std::string Compiler::GetModelProto() { std::string Compiler::GetModelProto() { return builder_->getModelProto(); }
if (ipu_strategy_->enable_fp16) {
return GetFP16ModelProto();
} else {
return builder_->getModelProto();
}
}
void Compiler::SaveModelProto(const std::string& path) { void Compiler::SaveModelProto(const std::string& path) {
builder_->saveModelProto(path); builder_->saveModelProto(path);
......
@@ -17,16 +17,15 @@
#include <popart/builder.hpp> #include <popart/builder.hpp>
#include <popart/graphtransformer.hpp> #include <popart/graphtransformer.hpp>
#include <popart/optimizer.hpp> #include <popart/optimizer.hpp>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device/ipu/ipu_names.h"
#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
#include "paddle/fluid/platform/device/ipu/ipu_utils.h" #include "paddle/fluid/platform/device/ipu/ipu_utils.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace ipu { namespace ipu {
class IpuStrategy;
struct CompilerResources { struct CompilerResources {
// popart input tensor_ids // popart input tensor_ids
std::vector<popart::TensorId> inputs; std::vector<popart::TensorId> inputs;
@@ -81,30 +80,6 @@ struct GraphHelper {
std::vector<int> sorted_vars_id; std::vector<int> sorted_vars_id;
}; };
// Helper for adding namescope info
struct NameScopeHelper {
NameScopeHelper(const OpDesc *op, popart::Builder *builder)
: builder_(builder) {
auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
if (op_namescope.empty() || op_namescope == "/") {
return;
}
op_namescope.pop_back();
op_namescope.erase(op_namescope.begin());
builder->pushNameScope(op_namescope);
pushed_ = true;
}
~NameScopeHelper() {
if (pushed_) {
builder_->popNameScope();
}
}
bool pushed_ = false;
popart::Builder *builder_;
};
class Compiler { class Compiler {
public: public:
Compiler(); Compiler();
@@ -138,11 +113,6 @@ class Compiler {
const std::vector<std::string> &GetOpOutputs(const OpDesc *op); const std::vector<std::string> &GetOpOutputs(const OpDesc *op);
const std::string GetNameScope(const OpDesc *op); const std::string GetNameScope(const OpDesc *op);
popart::DebugContext BuildDebugContext(const OpDesc *op); popart::DebugContext BuildDebugContext(const OpDesc *op);
void InsertTensors(const std::vector<std::string> &output_names,
const std::vector<std::string> &tensor_ids);
void InsertTensors(const std::vector<std::string> &output_names,
const std::string &tensor_id);
void PostLower(const std::vector<std::string> &, const OpDesc *); void PostLower(const std::vector<std::string> &, const OpDesc *);
void PostLower(const std::string &, const OpDesc *); void PostLower(const std::string &, const OpDesc *);
void PostLower(const std::string &, const OpDesc *, bool); void PostLower(const std::string &, const OpDesc *, bool);
......
@@ -13,14 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/platform/device/ipu/ipu_device.h" #include "paddle/fluid/platform/device/ipu/ipu_device.h"
#include <popart/devicemanager.hpp>
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace ipu { namespace ipu {
// TODO(alleng) merge with ipu_utils namespace {
static bool GetBoolEnv(std::string str) { const bool GetBoolEnv(const std::string& str) {
char* str_val = getenv(str.c_str()); char* str_val = getenv(str.c_str());
if (str_val == NULL) { if (str_val == NULL) {
return false; return false;
@@ -32,6 +35,7 @@ static bool GetBoolEnv(std::string str) {
return val; return val;
} }
} }
} // namespace
int GetNumDevices() { int GetNumDevices() {
bool ipu_model = GetBoolEnv("POPLAR_IPUMODEL"); bool ipu_model = GetBoolEnv("POPLAR_IPUMODEL");
......
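GetBoolEnv returns false when the variable is unset, so the IPU-model path is opt-in. A minimal, hedged sketch of the intended use (the rest of GetNumDevices is truncated in this diff, so the branch bodies below are assumptions, not the actual implementation):

// Illustrative: choose between real IPUs and the simulated IPU model.
bool ipu_model = GetBoolEnv("POPLAR_IPUMODEL");  // false if the env var is unset
if (ipu_model) {
  // assumed: skip querying the device manager and report a fixed device count
} else {
  // query popart::DeviceManager for physically available IPUs
}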
@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <popart/devicemanager.hpp> #include <vector>
namespace paddle { namespace paddle {
namespace platform { namespace platform {
......
@@ -14,12 +14,17 @@ limitations under the License. */
#include "paddle/fluid/platform/device/ipu/ipu_executor.h" #include "paddle/fluid/platform/device/ipu/ipu_executor.h"
using float16 = paddle::platform::float16; #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device/ipu/ipu_compiler.h"
#include "paddle/fluid/platform/device/ipu/ipu_names.h"
#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace ipu { namespace ipu {
namespace {
// Get paddle prefix and popart postfix of weight states // Get paddle prefix and popart postfix of weight states
// Format: {popart_postfix, paddle_prefix} // Format: {popart_postfix, paddle_prefix}
std::vector<std::pair<std::string, std::string>> GetOptPrePostfix( std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
@@ -54,6 +59,35 @@ std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
return pre_post_fix; return pre_post_fix;
} }
class PdIArray final : public popart::IArray {
public:
explicit PdIArray(const Tensor *tensor) {
tensor_.ShareDataWith(*tensor);
for (int i = 0; i < tensor->dims().size(); ++i) {
shape_.push_back(tensor->dims().at(i));
}
}
public:
void *data() { return tensor_.data(); }
popart::DataType dataType() const {
return PhiDType2PopartDType(tensor_.dtype());
}
std::size_t rank() const { return tensor_.dims().size(); }
int64_t dim(size_t index) const { return tensor_.dims().at(index); }
std::size_t nelms() const {
return std::accumulate(shape_.begin(), shape_.end(),
static_cast<int64_t>(1), std::multiplies<int64_t>());
}
const popart::Shape shape() const { return shape_; }
private:
Tensor tensor_;
std::vector<int64_t> shape_;
};
} // namespace
Executor::~Executor() { Executor::~Executor() {
Detach(); Detach();
session_.reset(); session_.reset();
@@ -110,15 +144,15 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
VLOG(10) << "enter Executor::Run"; VLOG(10) << "enter Executor::Run";
// inputs // inputs
std::map<popart::TensorId, popart::IArray &> popart_inputs; std::map<popart::TensorId, popart::IArray &> popart_inputs;
std::map<popart::TensorId, PaddleIArray> input_wrappers; std::map<popart::TensorId, PdIArray> input_wrappers;
for (size_t i = 0; i < inputs.size(); i++) { for (size_t i = 0; i < inputs.size(); i++) {
auto tensor_id = compiler_resources_->inputs[i]; auto tensor_id = compiler_resources_->inputs[i];
input_wrappers.emplace(tensor_id, PaddleIArray(inputs[i])); input_wrappers.emplace(tensor_id, PdIArray(inputs[i]));
popart_inputs.emplace(tensor_id, input_wrappers.at(tensor_id)); popart_inputs.emplace(tensor_id, input_wrappers.at(tensor_id));
} }
// anchors // anchors
std::map<popart::TensorId, popart::IArray &> popart_anchors; std::map<popart::TensorId, popart::IArray &> popart_anchors;
std::map<popart::TensorId, PaddleIArray> anchor_wrappers; std::map<popart::TensorId, PdIArray> anchor_wrappers;
for (size_t i = 0; i < outputs.size(); i++) { for (size_t i = 0; i < outputs.size(); i++) {
auto tensor_id = compiler_resources_->outputs[i]; auto tensor_id = compiler_resources_->outputs[i];
// get dims & dtype from session // get dims & dtype from session
@@ -140,10 +174,10 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
auto *tensor = outputs[i]; auto *tensor = outputs[i];
tensor->Resize(phi::make_ddim(output_shape)); tensor->Resize(phi::make_ddim(output_shape));
auto fetch_dtype = fetch_info.dataType(); auto fetch_dtype = fetch_info.dataType();
auto paddle_type = PopartType2VarType(fetch_dtype); auto paddle_type = PopartDType2VarType(fetch_dtype);
tensor->mutable_data(ctx.GetPlace(), tensor->mutable_data(ctx.GetPlace(),
framework::TransToPhiDataType(paddle_type)); framework::TransToPhiDataType(paddle_type));
anchor_wrappers.emplace(tensor_id, PaddleIArray(tensor)); anchor_wrappers.emplace(tensor_id, PdIArray(tensor));
popart_anchors.emplace(tensor_id, anchor_wrappers.at(tensor_id)); popart_anchors.emplace(tensor_id, anchor_wrappers.at(tensor_id));
} }
VLOG(10) << "Prepared inputs/anchors"; VLOG(10) << "Prepared inputs/anchors";
@@ -203,16 +237,16 @@ void Executor::AcquireDevice() {
device_ = popart::DeviceManager::createDeviceManager().acquireDeviceById( device_ = popart::DeviceManager::createDeviceManager().acquireDeviceById(
device_id); device_id);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
device_, platform::errors::Unavailable( device_,
"Can't attach IPU in distribution, ipu_num = %d.", errors::Unavailable("Can't attach IPU in distribution, ipu_num = %d.",
RequestIpus(ipu_strategy_->num_ipus))); RequestIpus(ipu_strategy_->num_ipus)));
} else { } else {
device_ = device_ =
popart::DeviceManager::createDeviceManager().acquireAvailableDevice( popart::DeviceManager::createDeviceManager().acquireAvailableDevice(
RequestIpus(ipu_strategy_->num_ipus)); RequestIpus(ipu_strategy_->num_ipus));
PADDLE_ENFORCE_NOT_NULL(device_, platform::errors::Unavailable( PADDLE_ENFORCE_NOT_NULL(
"Can't attach IPU, ipu_num = %d.", device_, errors::Unavailable("Can't attach IPU, ipu_num = %d.",
RequestIpus(ipu_strategy_->num_ipus))); RequestIpus(ipu_strategy_->num_ipus)));
} }
VLOG(10) << "leave Executor::AcquireDevice"; VLOG(10) << "leave Executor::AcquireDevice";
} }
@@ -260,13 +294,13 @@ void Executor::SetWeightsIO() {
void Executor::ConvertWeights(bool align_to_popart) { void Executor::ConvertWeights(bool align_to_popart) {
for (auto weight_pair : executor_resources_->weights_and_opt_state) { for (auto weight_pair : executor_resources_->weights_and_opt_state) {
auto paddle_var = scope_->GetVar(weight_pair.second); auto paddle_var = scope_->GetVar(weight_pair.second);
auto paddle_var_dtype = PdDataType2PopartType( auto paddle_var_dtype = PhiDType2PopartDType(
paddle_var->GetMutable<framework::LoDTensor>()->dtype()); paddle_var->GetMutable<framework::LoDTensor>()->dtype());
PADDLE_ENFORCE_EQ((paddle_var_dtype == popart::DataType::FLOAT || PADDLE_ENFORCE_EQ((paddle_var_dtype == popart::DataType::FLOAT ||
paddle_var_dtype == popart::DataType::FLOAT16), paddle_var_dtype == popart::DataType::FLOAT16),
true, true,
platform::errors::InvalidArgument( errors::InvalidArgument(
"Currently, we only support FLOAT16 and FLOAT with " "Currently, we only support FLOAT16 and FLOAT with "
"Paddle, but received type is %s.", "Paddle, but received type is %s.",
paddle_var_dtype)); paddle_var_dtype));
@@ -276,7 +310,7 @@ void Executor::ConvertWeights(bool align_to_popart) {
PADDLE_ENFORCE_EQ((popart_var_dtype == popart::DataType::FLOAT || PADDLE_ENFORCE_EQ((popart_var_dtype == popart::DataType::FLOAT ||
popart_var_dtype == popart::DataType::FLOAT16), popart_var_dtype == popart::DataType::FLOAT16),
true, true,
platform::errors::InvalidArgument( errors::InvalidArgument(
"Currently, we only support FLOAT16 and FLOAT with " "Currently, we only support FLOAT16 and FLOAT with "
"popart, but received type is %s.", "popart, but received type is %s.",
popart_var_dtype)); popart_var_dtype));
@@ -310,8 +344,8 @@ void Executor::ConvertWeights(bool align_to_popart) {
num_elem * sizeof(float)); num_elem * sizeof(float));
} }
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(
"Convert Paddle FLOAT16 to popart FLOAT")); errors::Unimplemented("Convert Paddle FLOAT16 to popart FLOAT"));
} }
} }
} }
......
@@ -22,17 +22,21 @@ limitations under the License. */
#include <popart/tensorinfo.hpp> #include <popart/tensorinfo.hpp>
#include <popdist/popdist_poplar.hpp> #include <popdist/popdist_poplar.hpp>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device/ipu/ipu_compiler.h"
#include "paddle/fluid/platform/device/ipu/ipu_names.h"
#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
#include "paddle/fluid/platform/device/ipu/ipu_utils.h" #include "paddle/fluid/platform/device/ipu/ipu_utils.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
} // namespace paddle
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace ipu { namespace ipu {
struct CompilerResources;
class IpuStrategy;
struct ExecutorResources { struct ExecutorResources {
// map<tensor_id, paddle_var_ptr> // map<tensor_id, paddle_var_ptr>
popart::WeightsIO weights_io; popart::WeightsIO weights_io;
@@ -45,18 +49,18 @@ class Executor {
Executor() = default; Executor() = default;
~Executor(); ~Executor();
// build popart session // Build popart session
void Prepare(const std::string &proto); void Prepare(const std::string &proto);
// run popart session // Run popart session
void Run(const std::vector<const Tensor *> &inputs, void Run(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs, const std::vector<Tensor *> &outputs,
const framework::ExecutionContext &ctx); const framework::ExecutionContext &ctx);
// sync weights from popart to paddle // Sync weights from popart to paddle
void WeightsToHost(); void WeightsToHost();
// detach IPU // Detach IPU
void Detach(); void Detach();
// Scope // Scope
@@ -83,16 +87,16 @@ class Executor {
void WeightsToPaddle(); void WeightsToPaddle();
private: private:
// not own // Not own
const Scope *scope_ = nullptr; const Scope *scope_ = nullptr;
const IpuStrategy *ipu_strategy_ = nullptr; const IpuStrategy *ipu_strategy_ = nullptr;
CompilerResources *compiler_resources_ = nullptr; CompilerResources *compiler_resources_ = nullptr;
// deviceinfo for popart session // Deviceinfo for popart session
std::shared_ptr<popart::DeviceInfo> device_; std::shared_ptr<popart::DeviceInfo> device_;
// popart session, where graph running // Popart session, where the graph runs
std::unique_ptr<popart::Session> session_; std::unique_ptr<popart::Session> session_;
// one OneSession means a graph // An ExecutorResources corresponds to a graph
std::unique_ptr<ExecutorResources> executor_resources_; std::unique_ptr<ExecutorResources> executor_resources_;
}; };
......
@@ -316,8 +316,10 @@ IpuStrategy::IpuStrategy() {
RegisterSetter(bool_options, "enable_half_partial", [&](bool value) { RegisterSetter(bool_options, "enable_half_partial", [&](bool value) {
if (value) { if (value) {
popart_options.partialsTypeMatMuls = "half"; popart_options.partialsTypeMatMuls = "half";
popart_options.convolutionOptions.insert({{"partialsType", "half"}});
} else { } else {
popart_options.partialsTypeMatMuls = "float"; popart_options.partialsTypeMatMuls = "float";
popart_options.convolutionOptions.insert({{"partialsType", "float"}});
} }
}); });
......
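With this change the option controls partial-accumulation precision for both matmuls and convolutions. A hedged sketch of toggling it from C++ follows; the setter name AddBoolOption is an assumption about the IpuStrategy API, and the option can also be driven from the Python-side ipu_strategy options.

// Illustrative only:
IpuStrategy strategy;
strategy.AddBoolOption("enable_half_partial", true);
// expected effect, per the setter registered above:
//   popart_options.partialsTypeMatMuls == "half"
//   popart_options.convolutionOptions["partialsType"] == "half"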
@@ -13,133 +13,111 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/platform/device/ipu/ipu_utils.h" #include "paddle/fluid/platform/device/ipu/ipu_utils.h"
#include <cmath> #include <cmath>
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace ipu { namespace ipu {
void* PaddleIArray::data() { return tensor_.data(); } const popart::DataType VarType2PopartDType(const VarType::Type type) {
popart::DataType PaddleIArray::dataType() const {
return PdDataType2PopartType(tensor_.dtype());
}
std::size_t PaddleIArray::rank() const { return tensor_.dims().size(); }
int64_t PaddleIArray::dim(size_t index) const {
return tensor_.dims().at(index);
}
std::size_t PaddleIArray::nelms() const {
return std::accumulate(shape_.begin(), shape_.end(), static_cast<int64_t>(1),
std::multiplies<int64_t>());
}
const popart::Shape PaddleIArray::shape() const { return shape_; }
popart::DataType VarType2PopartType(
const framework::proto::VarType::Type type) {
switch (type) { switch (type) {
case framework::proto::VarType::UINT8: case VarType::UINT8:
return popart::DataType::UINT8; return popart::DataType::UINT8;
case framework::proto::VarType::INT8: case VarType::INT8:
return popart::DataType::INT8; return popart::DataType::INT8;
case framework::proto::VarType::INT16: case VarType::INT16:
return popart::DataType::INT16; return popart::DataType::INT16;
case framework::proto::VarType::INT32: case VarType::INT32:
return popart::DataType::INT32; return popart::DataType::INT32;
case framework::proto::VarType::INT64: case VarType::INT64:
return popart::DataType::INT64; return popart::DataType::INT64;
case framework::proto::VarType::BOOL: case VarType::BOOL:
return popart::DataType::BOOL; return popart::DataType::BOOL;
case framework::proto::VarType::FP64: case VarType::FP64:
return popart::DataType::DOUBLE; return popart::DataType::DOUBLE;
case framework::proto::VarType::FP32: case VarType::FP32:
return popart::DataType::FLOAT; return popart::DataType::FLOAT;
case framework::proto::VarType::FP16: case VarType::FP16:
return popart::DataType::FLOAT16; return popart::DataType::FLOAT16;
case framework::proto::VarType::BF16: case VarType::BF16:
return popart::DataType::BFLOAT16; return popart::DataType::BFLOAT16;
case framework::proto::VarType::COMPLEX64: case VarType::COMPLEX64:
return popart::DataType::COMPLEX64; return popart::DataType::COMPLEX64;
case framework::proto::VarType::COMPLEX128: case VarType::COMPLEX128:
return popart::DataType::COMPLEX128; return popart::DataType::COMPLEX128;
default: default:
PADDLE_THROW(paddle::platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported Paddle var type.")); "Unsupported VarType::Type when converting to popart data type."));
} }
} }
popart::DataType PdDataType2PopartType( const popart::DataType PhiDType2PopartDType(const phi::DataType type) {
const paddle::experimental::DataType type) {
switch (type) { switch (type) {
case paddle::experimental::DataType::UINT8: case phi::DataType::UINT8:
return popart::DataType::UINT8; return popart::DataType::UINT8;
case paddle::experimental::DataType::INT8: case phi::DataType::INT8:
return popart::DataType::INT8; return popart::DataType::INT8;
case paddle::experimental::DataType::INT16: case phi::DataType::INT16:
return popart::DataType::INT16; return popart::DataType::INT16;
case paddle::experimental::DataType::INT32: case phi::DataType::INT32:
return popart::DataType::INT32; return popart::DataType::INT32;
case paddle::experimental::DataType::INT64: case phi::DataType::INT64:
return popart::DataType::INT64; return popart::DataType::INT64;
case paddle::experimental::DataType::BOOL: case phi::DataType::BOOL:
return popart::DataType::BOOL; return popart::DataType::BOOL;
case paddle::experimental::DataType::FLOAT64: case phi::DataType::FLOAT64:
return popart::DataType::DOUBLE; return popart::DataType::DOUBLE;
case paddle::experimental::DataType::FLOAT32: case phi::DataType::FLOAT32:
return popart::DataType::FLOAT; return popart::DataType::FLOAT;
case paddle::experimental::DataType::FLOAT16: case phi::DataType::FLOAT16:
return popart::DataType::FLOAT16; return popart::DataType::FLOAT16;
case paddle::experimental::DataType::BFLOAT16: case phi::DataType::BFLOAT16:
return popart::DataType::BFLOAT16; return popart::DataType::BFLOAT16;
case paddle::experimental::DataType::COMPLEX64: case phi::DataType::COMPLEX64:
return popart::DataType::COMPLEX64; return popart::DataType::COMPLEX64;
case paddle::experimental::DataType::COMPLEX128: case phi::DataType::COMPLEX128:
return popart::DataType::COMPLEX128; return popart::DataType::COMPLEX128;
default: default:
PADDLE_THROW(paddle::platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported Paddle data type.")); "Unsupported phi::DataType when converting to popart data type."));
} }
} }
framework::proto::VarType::Type PopartType2VarType( const VarType::Type PopartDType2VarType(const popart::DataType type) {
const popart::DataType type) {
switch (type) { switch (type) {
case popart::DataType::UINT8: case popart::DataType::UINT8:
return framework::proto::VarType::UINT8; return VarType::UINT8;
case popart::DataType::INT8: case popart::DataType::INT8:
return framework::proto::VarType::INT8; return VarType::INT8;
case popart::DataType::INT16: case popart::DataType::INT16:
return framework::proto::VarType::INT16; return VarType::INT16;
case popart::DataType::INT32: case popart::DataType::INT32:
return framework::proto::VarType::INT32; return VarType::INT32;
case popart::DataType::INT64: case popart::DataType::INT64:
return framework::proto::VarType::INT64; return VarType::INT64;
case popart::DataType::BOOL: case popart::DataType::BOOL:
return framework::proto::VarType::BOOL; return VarType::BOOL;
case popart::DataType::DOUBLE: case popart::DataType::DOUBLE:
return framework::proto::VarType::FP64; return VarType::FP64;
case popart::DataType::FLOAT: case popart::DataType::FLOAT:
return framework::proto::VarType::FP32; return VarType::FP32;
case popart::DataType::FLOAT16: case popart::DataType::FLOAT16:
return framework::proto::VarType::FP16; return VarType::FP16;
case popart::DataType::BFLOAT16: case popart::DataType::BFLOAT16:
return framework::proto::VarType::BF16; return VarType::BF16;
case popart::DataType::COMPLEX64: case popart::DataType::COMPLEX64:
return framework::proto::VarType::COMPLEX64; return VarType::COMPLEX64;
case popart::DataType::COMPLEX128: case popart::DataType::COMPLEX128:
return framework::proto::VarType::COMPLEX128; return VarType::COMPLEX128;
default: default:
PADDLE_THROW(paddle::platform::errors::Unavailable( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported Paddle var type.")); "Unsupported popart::DataType when converting to var type."));
} }
} }
popart::DataType OnnxDtype2PopartType(const int type) { const popart::DataType OnnxDType2PopartType(const ONNXDataType type) {
auto dtype = static_cast<ONNXDataType>(type); switch (type) {
switch (dtype) {
case ONNXDataType::BOOL: case ONNXDataType::BOOL:
return popart::DataType::BOOL; return popart::DataType::BOOL;
case ONNXDataType::INT16: case ONNXDataType::INT16:
@@ -166,12 +144,69 @@ popart::DataType OnnxDtype2PopartType(const int type) {
return popart::DataType::COMPLEX128; return popart::DataType::COMPLEX128;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported ONNX data type: %d.", dtype)); "Unsupported ONNXDataType when converting to popart data type."));
}
}
const ONNXDataType VarType2OnnxDType(const VarType::Type type) {
switch (type) {
case VarType::BOOL:
return ONNXDataType::BOOL;
case VarType::INT16:
return ONNXDataType::INT16;
case VarType::INT32:
return ONNXDataType::INT32;
case VarType::INT64:
return ONNXDataType::INT64;
case VarType::FP16:
return ONNXDataType::FLOAT16;
case VarType::FP32:
return ONNXDataType::FLOAT;
case VarType::FP64:
return ONNXDataType::DOUBLE;
case VarType::UINT8:
return ONNXDataType::UINT8;
case VarType::INT8:
return ONNXDataType::INT8;
case VarType::BF16:
return ONNXDataType::BFLOAT16;
case VarType::COMPLEX64:
return ONNXDataType::COMPLEX64;
case VarType::COMPLEX128:
return ONNXDataType::COMPLEX128;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported VarType::Type when converting to onnx data type."));
}
}
const std::string VarType2PopartStr(const VarType::Type type) {
switch (type) {
case VarType::UINT8:
return "UINT8";
case VarType::INT8:
return "INT8";
case VarType::INT16:
return "INT16";
case VarType::INT32:
return "INT32";
case VarType::INT64:
return "INT64";
case VarType::BOOL:
return "BOOL";
case VarType::FP64:
return "DOUBLE";
case VarType::FP32:
return "FLOAT";
case VarType::FP16:
return "FLOAT16";
default:
PADDLE_THROW(platform::errors::Unavailable(
"Unsupported VarType::Type when converting to popart type string."));
} }
} }
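Taken together these helpers form the dtype bridge between Paddle and popart/onnx. A small illustrative round trip, with the results following directly from the switch statements above:

auto popart_dtype     = VarType2PopartDType(VarType::FP32);    // popart::DataType::FLOAT
auto var_type         = PopartDType2VarType(popart_dtype);     // VarType::FP32
auto onnx_dtype       = VarType2OnnxDType(VarType::FP16);      // ONNXDataType::FLOAT16
auto popart_from_onnx = OnnxDType2PopartType(onnx_dtype);      // popart::DataType::FLOAT16
auto type_str         = VarType2PopartStr(VarType::INT32);     // "INT32"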
// count num should > 0 const bool GetBoolEnv(const std::string& str) {
bool GetBoolEnv(std::string str) {
char* str_val = getenv(str.c_str()); char* str_val = getenv(str.c_str());
if (str_val == NULL) { if (str_val == NULL) {
return false; return false;
@@ -184,8 +219,7 @@ bool GetBoolEnv(std::string str) {
} }
} }
int RequestIpus(const int num_ipus) { const int RequestIpus(const int num_ipus) {
// num_ipus must be pow(2, n);
return std::pow(2, ceil(log2(num_ipus))); return std::pow(2, ceil(log2(num_ipus)));
} }
......
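RequestIpus rounds the requested count up to the next power of two, since IPUs can only be acquired in power-of-two groups (see the header comment below). A few worked values, following directly from pow(2, ceil(log2(num_ipus))):

assert(RequestIpus(1) == 1);
assert(RequestIpus(2) == 2);
assert(RequestIpus(3) == 4);  // 3 rounds up to pow(2, ceil(log2(3))) = 4
assert(RequestIpus(5) == 8);
assert(RequestIpus(8) == 8);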
@@ -19,155 +19,32 @@ limitations under the License. */
#include <popart/tensorinfo.hpp> #include <popart/tensorinfo.hpp>
#include <popart/vendored/any.hpp> #include <popart/vendored/any.hpp>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
using float16 = paddle::platform::float16;
using Tensor = paddle::framework::Tensor;
using LoDTensor = paddle::framework::LoDTensor;
using Scope = paddle::framework::Scope;
using OpDesc = paddle::framework::OpDesc;
using Graph = paddle::framework::ir::Graph;
using Node = paddle::framework::ir::Node;
using BlockDesc = paddle::framework::BlockDesc;
using VarType = paddle::framework::proto::VarType;
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace ipu { namespace ipu {
using float16 = platform::float16;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using Scope = framework::Scope;
using OpDesc = framework::OpDesc;
using Graph = framework::ir::Graph;
using Node = framework::ir::Node;
using BlockDesc = framework::BlockDesc;
// onnx dtype
// https://github.com/onnx/onnx/blob/master/onnx/onnx-ml.proto3
enum ONNXDataType : int {
UNDEFINED = 0,
FLOAT = 1,
UINT8 = 2,
INT8 = 3,
UINT16 = 4,
INT16 = 5,
INT32 = 6,
INT64 = 7,
STRING = 8,
BOOL = 9,
FLOAT16 = 10,
DOUBLE = 11,
UINT32 = 12,
UINT64 = 13,
COMPLEX64 = 14,
COMPLEX128 = 15,
BFLOAT16 = 16
};
class PaddleIArray final : public popart::IArray {
public:
explicit PaddleIArray(const Tensor* tensor) {
tensor_.ShareDataWith(*tensor);
for (int i = 0; i < tensor->dims().size(); ++i) {
shape_.push_back(tensor->dims().at(i));
}
}
public:
void* data();
popart::DataType dataType() const;
std::size_t rank() const;
int64_t dim(size_t index) const;
std::size_t nelms() const;
const popart::Shape shape() const;
private:
Tensor tensor_;
std::vector<int64_t> shape_;
};
popart::DataType VarType2PopartType(const framework::proto::VarType::Type type);
popart::DataType PdDataType2PopartType(
const paddle::experimental::DataType type);
framework::proto::VarType::Type PopartType2VarType(const popart::DataType type);
popart::DataType OnnxDtype2PopartType(const int type);
bool GetBoolEnv(std::string str);
template <typename T>
std::unique_ptr<popart::NDArrayWrapper<T>> Tensor2IArray(const Tensor& tensor) {
auto dtype = PdDataType2PopartType(tensor.dtype());
auto shape = std::vector<int64_t>();
for (size_t i = 0; i < tensor.dims().size(); ++i) {
shape.push_back(tensor.dims().at(i));
}
popart::TensorInfo tensor_info(dtype, shape);
return std::make_unique<popart::NDArrayWrapper<T>>(
reinterpret_cast<T*>(tensor.data()), tensor_info);
}
template <typename T>
std::unique_ptr<popart::NDArrayWrapper<T>> LoDTensor2IArray(
LoDTensor const& lod_tensor) {
if (lod_tensor.lod().size() == 0) {
return Tensor2IArray<T>(lod_tensor);
} else {
PADDLE_THROW(
platform::errors::Unimplemented("LoDTensor2IArray is Unimplemented"));
}
}
template <typename T> template <typename T>
T GetSingleVarFromScope(const Scope* scope, const std::string& var_name) { T GetSingleVarFromScope(const Scope* scope, const std::string& var_name) {
auto var = scope->GetVar(var_name); auto var = scope->GetVar(var_name);
auto tensor = var->Get<framework::LoDTensor>(); auto tensor = var->Get<framework::LoDTensor>();
// check dtype is ?
return tensor.data<T>()[0]; return tensor.data<T>()[0];
} }
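GetSingleVarFromScope reads the first element of a tensor held in the scope, which is convenient for scalar hyper-parameters. A hedged example; the variable name below is hypothetical:

// Illustrative only: fetch a scalar learning rate that an optimizer op keeps
// in the scope as a single-element LoDTensor.
float lr = GetSingleVarFromScope<float>(scope, "learning_rate_0");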
struct CustomOpAttrVisitor : public boost::static_visitor<void> {
explicit CustomOpAttrVisitor(std::map<std::string, popart::any>* attr,
const std::string& attr_name)
: attrs_(attr), attr_name_(attr_name) {}
mutable std::map<std::string, popart::any>* attrs_;
std::string attr_name_;
void operator()(int v) const { attrs_->emplace(attr_name_, v); }
void operator()(float v) const { attrs_->emplace(attr_name_, v); }
void operator()(const std::string& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(const std::vector<int>& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(const std::vector<float>& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(const std::vector<std::string>& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(bool v) const { attrs_->emplace(attr_name_, v); }
void operator()(const std::vector<bool>& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(BlockDesc* desc) const {
PADDLE_THROW(platform::errors::Unavailable(
"Unsupported calling method for `BlockDesc` type."));
}
void operator()(const std::vector<BlockDesc*>& v) const {
PADDLE_THROW(platform::errors::Unavailable(
"Unsupported calling method for `BlockDesc` type."));
}
void operator()(int64_t v) const { attrs_->emplace(attr_name_, v); }
void operator()(const std::vector<int64_t>& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(const std::vector<double>& v) const {
attrs_->emplace(attr_name_, v);
}
void operator()(boost::blank) const {
PADDLE_THROW(platform::errors::Unavailable(
"Unsupported calling method for `boost::blank` type."));
}
};
struct IpuCustomOpIdentifier { struct IpuCustomOpIdentifier {
IpuCustomOpIdentifier(const std::string& _paddle_op, IpuCustomOpIdentifier(const std::string& _paddle_op,
const std::string& _popart_op, const std::string& _popart_op,
@@ -185,51 +62,44 @@ struct IpuCustomOpIdentifier {
popart::OperatorIdentifier popart_op; popart::OperatorIdentifier popart_op;
}; };
struct ConstantOpAttrVisitor : public boost::static_visitor<void> { // Onnx dtype
explicit ConstantOpAttrVisitor(framework::LoDTensor* tensor, // https://github.com/onnx/onnx/blob/master/onnx/onnx-ml.proto3
framework::proto::VarType::Type dtype) enum ONNXDataType : int {
: tensor_(tensor), dtype_(dtype) {} UNDEFINED = 0,
framework::LoDTensor* tensor_; FLOAT = 1,
framework::proto::VarType::Type dtype_; UINT8 = 2,
INT8 = 3,
void operator()(const std::vector<int>& vec) const { UINT16 = 4,
framework::TensorFromVector<int>(vec, tensor_); INT16 = 5,
} INT32 = 6,
void operator()(const std::vector<float>& vec) const { INT64 = 7,
if (dtype_ == framework::proto::VarType::FP16) { STRING = 8,
std::vector<float16> vec_fp16; BOOL = 9,
std::transform(vec.begin(), vec.end(), std::back_inserter(vec_fp16), FLOAT16 = 10,
[](float f) -> float16 { return float16(f); }); DOUBLE = 11,
framework::TensorFromVector<float16>(vec_fp16, tensor_); UINT32 = 12,
} else { UINT64 = 13,
framework::TensorFromVector<float>(vec, tensor_); COMPLEX64 = 14,
} COMPLEX128 = 15,
} BFLOAT16 = 16
void operator()(const std::vector<bool>& vec) const {
framework::TensorFromVector<bool>(vec, tensor_);
}
void operator()(const std::vector<int64_t>& vec) const {
framework::TensorFromVector<int64_t>(vec, tensor_);
}
void operator()(const std::vector<double>& vec) const {
framework::TensorFromVector<double>(vec, tensor_);
}
void RaiseError() const {
PADDLE_THROW(
platform::errors::InvalidArgument("Constant value must be a vector"));
}
void operator()(int v) const { RaiseError(); }
void operator()(float v) const { RaiseError(); }
void operator()(const std::string& v) const { RaiseError(); }
void operator()(const std::vector<std::string>& v) const { RaiseError(); }
void operator()(bool v) const { RaiseError(); }
void operator()(BlockDesc* desc) const { RaiseError(); }
void operator()(const std::vector<BlockDesc*>& v) const { RaiseError(); }
void operator()(int64_t v) const { RaiseError(); }
void operator()(boost::blank) const { RaiseError(); }
}; };
int RequestIpus(const int num_ipus); // VarType::Type to popart::DataType
const popart::DataType VarType2PopartDType(const VarType::Type type);
// phi::DataType to popart::DataType
const popart::DataType PhiDType2PopartDType(const phi::DataType type);
// popart::DataType to VarType::Type
const VarType::Type PopartDType2VarType(const popart::DataType type);
// ONNXDataType to popart::DataType
const popart::DataType OnnxDType2PopartType(const ONNXDataType type);
// VarType::Type to ONNXDataType
const ONNXDataType VarType2OnnxDType(const VarType::Type type);
// VarType::Type to String in Popart
const std::string VarType2PopartStr(const VarType::Type type);
// Get bool from environment variable
const bool GetBoolEnv(const std::string& str);
// Request number of ipus must be pow(2, n)
const int RequestIpus(const int num_ipus);
} // namespace ipu } // namespace ipu
} // namespace platform } // namespace platform
......
@@ -56,15 +56,15 @@ Node *gelu_handler(Graph *graph, Node *node) {
auto sqrt2 = CreateConst(graph, node, {}, {}, auto sqrt2 = CreateConst(graph, node, {}, {},
{{"value", std::vector<float>{1.4142135623730951}}, {{"value", std::vector<float>{1.4142135623730951}},
{"dims", std::vector<int64_t>{1}}, {"dims", std::vector<int64_t>{1}},
{"dtype", GetOutputVarDtype(node)}}); {"dtype", GetOutputVarDType(node)}});
auto zero_point_five = auto zero_point_five =
CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{0.5}}, CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{0.5}},
{"dims", std::vector<int64_t>{1}}, {"dims", std::vector<int64_t>{1}},
{"dtype", GetOutputVarDtype(node)}}); {"dtype", GetOutputVarDType(node)}});
auto one = auto one =
CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{1}}, CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{1}},
{"dims", std::vector<int64_t>{1}}, {"dims", std::vector<int64_t>{1}},
{"dtype", GetOutputVarDtype(node)}}); {"dtype", GetOutputVarDType(node)}});
auto div = auto div =
CreateBaseOp(graph, node, "popart_div", CreateBaseOp(graph, node, "popart_div",
{GetInputVarNode("X", node), sqrt2->outputs[0]}, {}, {}); {GetInputVarNode("X", node), sqrt2->outputs[0]}, {}, {});
......
...@@ -18,7 +18,6 @@ namespace paddle { ...@@ -18,7 +18,6 @@ namespace paddle {
namespace platform { namespace platform {
namespace ipu { namespace ipu {
// This avoids the static initialisation order fiasco,
std::unordered_map<std::string, SymbolHandler> &SymbolHandlers() { std::unordered_map<std::string, SymbolHandler> &SymbolHandlers() {
static std::unordered_map<std::string, SymbolHandler> symbol_handlers; static std::unordered_map<std::string, SymbolHandler> symbol_handlers;
return symbol_handlers; return symbol_handlers;
...@@ -34,8 +33,6 @@ bool RegisterHandler(const std::string &symbol, const SymbolHandler &handler) { ...@@ -34,8 +33,6 @@ bool RegisterHandler(const std::string &symbol, const SymbolHandler &handler) {
return new_handler; return new_handler;
} }
// Return a pointer to a handler if one is registered for this kind of node or
// an empty std::function otherwise.
SymbolHandler GetHandler(const std::string &kind) { SymbolHandler GetHandler(const std::string &kind) {
auto it = SymbolHandlers().find(kind); auto it = SymbolHandlers().find(kind);
if (it != SymbolHandlers().end()) { if (it != SymbolHandlers().end()) {
...@@ -84,66 +81,6 @@ void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op, ...@@ -84,66 +81,6 @@ void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op,
} }
} }
const int VarType2OnnxDtype(const int type) {
auto dtype = static_cast<framework::proto::VarType::Type>(type);
switch (dtype) {
case framework::proto::VarType::BOOL:
return static_cast<int>(ONNXDataType::BOOL);
case framework::proto::VarType::INT16:
return static_cast<int>(ONNXDataType::INT16);
case framework::proto::VarType::INT32:
return static_cast<int>(ONNXDataType::INT32);
case framework::proto::VarType::INT64:
return static_cast<int>(ONNXDataType::INT64);
case framework::proto::VarType::FP16:
return static_cast<int>(ONNXDataType::FLOAT16);
case framework::proto::VarType::FP32:
return static_cast<int>(ONNXDataType::FLOAT);
case framework::proto::VarType::FP64:
return static_cast<int>(ONNXDataType::DOUBLE);
case framework::proto::VarType::UINT8:
return static_cast<int>(ONNXDataType::UINT8);
case framework::proto::VarType::INT8:
return static_cast<int>(ONNXDataType::INT8);
case framework::proto::VarType::BF16:
return static_cast<int>(ONNXDataType::BFLOAT16);
case framework::proto::VarType::COMPLEX64:
return static_cast<int>(ONNXDataType::COMPLEX64);
case framework::proto::VarType::COMPLEX128:
return static_cast<int>(ONNXDataType::COMPLEX128);
default:
PADDLE_THROW(
platform::errors::Unimplemented("Unsupported data type: %d.", dtype));
}
}
const std::string VarType2PopStr(const int type) {
auto dtype = static_cast<framework::proto::VarType::Type>(type);
switch (dtype) {
case framework::proto::VarType::UINT8:
return "UINT8";
case framework::proto::VarType::INT8:
return "INT8";
case framework::proto::VarType::INT16:
return "INT16";
case framework::proto::VarType::INT32:
return "INT32";
case framework::proto::VarType::INT64:
return "INT64";
case framework::proto::VarType::BOOL:
return "BOOL";
case framework::proto::VarType::FP64:
return "DOUBLE";
case framework::proto::VarType::FP32:
return "FLOAT";
case framework::proto::VarType::FP16:
return "FLOAT16";
default:
PADDLE_THROW(
paddle::platform::errors::Unavailable("Unsupported data type."));
}
}
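The two int-based helpers removed above are superseded by the typed VarType2OnnxDType and VarType2PopartStr declared in ipu_utils.h. A sketch, assuming the replacement keeps the same mapping as the removed code; the actual definition is not shown in this diff:

// Sketch only -- mapping copied from the removed VarType2PopStr above.
const std::string VarType2PopartStr(const VarType::Type type) {
  switch (type) {
    case VarType::UINT8: return "UINT8";
    case VarType::INT8:  return "INT8";
    case VarType::INT16: return "INT16";
    case VarType::INT32: return "INT32";
    case VarType::INT64: return "INT64";
    case VarType::BOOL:  return "BOOL";
    case VarType::FP64:  return "DOUBLE";
    case VarType::FP32:  return "FLOAT";
    case VarType::FP16:  return "FLOAT16";
    default:
      PADDLE_THROW(
          paddle::platform::errors::Unavailable("Unsupported data type."));
  }
}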
Node *GetInputVarNode(const std::string &input_name, const Node *op_node, Node *GetInputVarNode(const std::string &input_name, const Node *op_node,
const int id) { const int id) {
auto var_name = op_node->Op()->Input(input_name).at(id); auto var_name = op_node->Op()->Input(input_name).at(id);
...@@ -180,7 +117,7 @@ const bool is_float_equal(float a, float b, float eps) { ...@@ -180,7 +117,7 @@ const bool is_float_equal(float a, float b, float eps) {
return std::fabs(a - b) <= eps; return std::fabs(a - b) <= eps;
} }
const int GetOutputVarDtype(const Node *node, const std::string &output_name) { const int GetOutputVarDType(const Node *node, const std::string &output_name) {
auto out_node = GetOutputVarNode(output_name, node); auto out_node = GetOutputVarNode(output_name, node);
PADDLE_ENFORCE_NOT_NULL(out_node, platform::errors::Unavailable( PADDLE_ENFORCE_NOT_NULL(out_node, platform::errors::Unavailable(
"Node's out node does not exist.")); "Node's out node does not exist."));
...@@ -188,7 +125,7 @@ const int GetOutputVarDtype(const Node *node, const std::string &output_name) { ...@@ -188,7 +125,7 @@ const int GetOutputVarDtype(const Node *node, const std::string &output_name) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::Unavailable("Node is not a variable.")); var, platform::errors::Unavailable("Node is not a variable."));
auto proto_var_type = var->GetDataType(); auto proto_var_type = var->GetDataType();
return VarType2OnnxDtype(proto_var_type); return static_cast<int>(VarType2OnnxDType(proto_var_type));
} }
} // namespace ipu } // namespace ipu
......
...@@ -68,9 +68,6 @@ void ClearNode(Node *node); ...@@ -68,9 +68,6 @@ void ClearNode(Node *node);
void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op, void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op,
bool override = false); bool override = false);
const int VarType2OnnxDtype(const int type);
const std::string VarType2PopStr(const int type);
Node *GetInputVarNode(const std::string &input_name, const Node *op_node, Node *GetInputVarNode(const std::string &input_name, const Node *op_node,
const int id = 0); const int id = 0);
Node *GetOutputVarNode(const std::string &output_name, const Node *op_node, Node *GetOutputVarNode(const std::string &output_name, const Node *op_node,
...@@ -81,7 +78,7 @@ Node *GetOutputVarNodeByVarName(const std::string &var_name, ...@@ -81,7 +78,7 @@ Node *GetOutputVarNodeByVarName(const std::string &var_name,
const Node *op_node); const Node *op_node);
const bool is_float_equal(float a, float b, float eps = 1e-8); const bool is_float_equal(float a, float b, float eps = 1e-8);
const int GetOutputVarDtype(const Node *node, const int GetOutputVarDType(const Node *node,
const std::string &output_name = "Out"); const std::string &output_name = "Out");
} // namespace ipu } // namespace ipu
......
...@@ -28,6 +28,14 @@ Node *equal_handler(Graph *graph, Node *node) { ...@@ -28,6 +28,14 @@ Node *equal_handler(Graph *graph, Node *node) {
return new_node; return new_node;
} }
Node *not_equal_handler(Graph *graph, Node *node) {
auto equal_node = CreateBaseOp(
graph, node, "popart_equal",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)}, {});
return CreateBaseOp(graph, node, "popart_logical_not",
{equal_node->outputs[0]}, node->outputs, {});
}
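Presumably because the popart/ONNX op set exposes Equal and Not but no direct NotEqual, the new handler composes the two ops; elementwise this is simply the negated comparison:

// Reference-only: not_equal(x, y) == logical_not(equal(x, y)) elementwise.
bool not_equal_scalar(float x, float y) { return !(x == y); }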
Node *logical_not_handler(Graph *graph, Node *node) { Node *logical_not_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_logical_not", return CreateBaseOp(graph, node, "popart_logical_not",
{GetInputVarNode("X", node)}, {GetInputVarNode("X", node)},
...@@ -64,6 +72,7 @@ Node *less_than_handler(Graph *graph, Node *node) { ...@@ -64,6 +72,7 @@ Node *less_than_handler(Graph *graph, Node *node) {
} // namespace paddle } // namespace paddle
REGISTER_HANDLER(equal, equal_handler); REGISTER_HANDLER(equal, equal_handler);
REGISTER_HANDLER(not_equal, not_equal_handler);
REGISTER_HANDLER(logical_not, logical_not_handler); REGISTER_HANDLER(logical_not, logical_not_handler);
REGISTER_HANDLER(logical_or, logical_or_handler); REGISTER_HANDLER(logical_or, logical_or_handler);
REGISTER_HANDLER(logical_and, logical_and_handler); REGISTER_HANDLER(logical_and, logical_and_handler);
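The new not_equal handler is wired in through the registry shown earlier (SymbolHandlers / RegisterHandler / GetHandler). The exact REGISTER_HANDLER definition is not in this excerpt; the sketch below is an assumption about how such a registration macro typically looks:

// Hypothetical expansion -- the real macro may differ in naming details.
#define REGISTER_HANDLER(name, func) \
  static bool __handler_registered_##name = RegisterHandler(#name, func)
// Lookup side, as used by the canonicalization pass:
//   if (auto handler = GetHandler(node->Op()->Type())) handler(graph, node);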
......
...@@ -41,7 +41,7 @@ Node *pow_handler(Graph *graph, Node *node) { ...@@ -41,7 +41,7 @@ Node *pow_handler(Graph *graph, Node *node) {
// Op(pow) -> Op(Constant)->Var(const_out)->Op(Pow) // Op(pow) -> Op(Constant)->Var(const_out)->Op(Pow)
auto value_ = BOOST_GET_CONST(float, op->GetAttr("factor")); auto value_ = BOOST_GET_CONST(float, op->GetAttr("factor"));
auto attrs = auto attrs =
MakeConstAttrMapFromValue<float>(value_, {1}, GetOutputVarDtype(node)); MakeConstAttrMapFromValue<float>(value_, {1}, GetOutputVarDType(node));
auto new_node_const = CreateConst(graph, node, {}, {}, attrs); auto new_node_const = CreateConst(graph, node, {}, {}, attrs);
return CreateBaseOp(graph, node, "popart_pow", {GetInputVarNode("X", node), return CreateBaseOp(graph, node, "popart_pow", {GetInputVarNode("X", node),
...@@ -134,7 +134,7 @@ Node *matmul_handler(Graph *graph, Node *node) { ...@@ -134,7 +134,7 @@ Node *matmul_handler(Graph *graph, Node *node) {
} else { } else {
auto o_node = auto o_node =
CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node}, {}); CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node}, {});
auto attr = MakeConstAttrMapFromValue(alpha, {1}, GetOutputVarDtype(node)); auto attr = MakeConstAttrMapFromValue(alpha, {1}, GetOutputVarDType(node));
auto const_node = CreateConst(graph, node, {}, {}, attr); auto const_node = CreateConst(graph, node, {}, {}, attr);
return CreateBaseOp(graph, node, "popart_mul", return CreateBaseOp(graph, node, "popart_mul",
{o_node->outputs[0], const_node->outputs[0]}, {o_node->outputs[0], const_node->outputs[0]},
...@@ -299,6 +299,80 @@ Node *cross_entropy2_handler(Graph *graph, Node *node) { ...@@ -299,6 +299,80 @@ Node *cross_entropy2_handler(Graph *graph, Node *node) {
} }
} }
Node *softmax_with_cross_entropy_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignore_index"));
auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
auto soft_label = BOOST_GET_CONST(bool, op->GetAttr("soft_label"));
if (soft_label) {
PADDLE_THROW(platform::errors::InvalidArgument(
"soft_label is not supported yet in IPU"));
}
Node *new_cast = nullptr;
if (GetInputVarNode("Label", node)->Var()->GetDataType() ==
framework::proto::VarType::INT32) {
new_cast = GetInputVarNode("Label", node);
} else {
    auto cast_op = CreateCast(graph, node, {GetInputVarNode("Label", node)},
                              {}, framework::proto::VarType::INT32);
    new_cast = cast_op->outputs[0];
}
auto softmax_node = CreateSoftmaxOpset11(
graph, node, {GetInputVarNode("Logits", node)}, {}, axis);
auto label_shape_ = GetInputVarNode("Label", node)->Var()->GetShape();
if (label_shape_[label_shape_.size() - 1] != 1) {
auto log = CreateBaseOp(graph, node, "popart_log",
{softmax_node->outputs[0]}, {}, {});
    // softmax_with_cross_entropy is split into several ops in Python,
    // so no reduction is needed here.
return CreateBaseOp(
graph, node, "popart_nllloss_v2", {log->outputs[0], new_cast},
{GetOutputVarNode("Loss", node)},
{
{"reduction", 2}, // popart::ReductionType::NoReduction
{"ignoreIndex", ignoreIndex},
{"inputIsLogProbability", true},
});
} else {
std::vector<int64_t> new_shape_{label_shape_[0]};
auto const_before_loss = CreateBaseOp(
graph, node, "popart_constant", {}, {},
{{"value", new_shape_},
{"dims",
std::vector<int64_t>{static_cast<int64_t>(new_shape_.size())}},
{"dtype", ONNXDataType::INT64}});
auto reshape_before_loss =
CreateBaseOp(graph, node, "popart_reshape",
{new_cast, const_before_loss->outputs[0]}, {}, {});
auto log = CreateBaseOp(graph, node, "popart_log",
{softmax_node->outputs[0]}, {}, {});
auto nllloss = CreateBaseOp(
graph, node, "popart_nllloss_v2",
{log->outputs[0], reshape_before_loss->outputs[0]}, {},
{
{"reduction", 2}, // popart::ReductionType::NoReduction
{"ignoreIndex", ignoreIndex},
{"inputIsLogProbability", true},
});
auto const_after_loss = CreateBaseOp(
graph, node, "popart_constant", {}, {},
{{"value", label_shape_},
{"dims",
std::vector<int64_t>{static_cast<int64_t>(label_shape_.size())}},
{"dtype", ONNXDataType::INT64}});
auto reshape_after_loss =
CreateBaseOp(graph, node, "popart_reshape",
{nllloss->outputs[0], const_after_loss->outputs[0]},
{GetOutputVarNode("Loss", node)}, {});
return reshape_after_loss;
}
}
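For reference, the chain built above (softmax -> log -> popart_nllloss_v2 with reduction = NoReduction and inputIsLogProbability = true) computes the usual per-row cross entropy, with rows whose label equals ignore_index masked out. A minimal sketch of that math, not part of the patch:

#include <algorithm>
#include <cmath>
#include <vector>
// Returns -log(softmax(logits)[label]); ignored labels contribute zero.
float softmax_xent_row(const std::vector<float>& logits, int label,
                       int ignore_index) {
  if (label == ignore_index) return 0.0f;
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  float sum = 0.0f;
  for (float l : logits) sum += std::exp(l - max_logit);
  return -(logits[label] - max_logit - std::log(sum));
}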
Node *cumsum_handler(Graph *graph, Node *node) { Node *cumsum_handler(Graph *graph, Node *node) {
auto *op = node->Op(); auto *op = node->Op();
auto exclusive = BOOST_GET_CONST(bool, op->GetAttr("exclusive")); auto exclusive = BOOST_GET_CONST(bool, op->GetAttr("exclusive"));
...@@ -378,6 +452,8 @@ REGISTER_HANDLER(matmul, matmul_handler); ...@@ -378,6 +452,8 @@ REGISTER_HANDLER(matmul, matmul_handler);
REGISTER_HANDLER(sum, sum_handler); REGISTER_HANDLER(sum, sum_handler);
REGISTER_HANDLER(softmax, softmax_handler); REGISTER_HANDLER(softmax, softmax_handler);
REGISTER_HANDLER(scale, scale_handler); REGISTER_HANDLER(scale, scale_handler);
REGISTER_HANDLER(softmax_with_cross_entropy,
softmax_with_cross_entropy_handler);
REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler); REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler);
REGISTER_HANDLER(cumsum, cumsum_handler); REGISTER_HANDLER(cumsum, cumsum_handler);
REGISTER_HANDLER(matmul_v2, matmul_v2_handler); REGISTER_HANDLER(matmul_v2, matmul_v2_handler);
......
...@@ -299,7 +299,7 @@ Node *dropout_handler(Graph *graph, Node *node) { ...@@ -299,7 +299,7 @@ Node *dropout_handler(Graph *graph, Node *node) {
CreateConst(graph, node, {}, {}, CreateConst(graph, node, {}, {},
{{"value", std::vector<float>{1 - dropout_prob_}}, {{"value", std::vector<float>{1 - dropout_prob_}},
{"dims", std::vector<int64_t>{1}}, {"dims", std::vector<int64_t>{1}},
{"dtype", GetOutputVarDtype(node)}}); {"dtype", GetOutputVarDType(node)}});
return CreateBaseOp(graph, node, "popart_mul", return CreateBaseOp(graph, node, "popart_mul",
{GetInputVarNode("X", node), scale->outputs[0]}, {GetInputVarNode("X", node), scale->outputs[0]},
{GetOutputVarNode("Out", node)}, {}); {GetOutputVarNode("Out", node)}, {});
......
...@@ -124,7 +124,7 @@ Node *CreateConst(Graph *graph, Node *node, const std::vector<Node *> &inputs, ...@@ -124,7 +124,7 @@ Node *CreateConst(Graph *graph, Node *node, const std::vector<Node *> &inputs,
Node *CreateCast(Graph *graph, Node *node, const std::vector<Node *> &inputs, Node *CreateCast(Graph *graph, Node *node, const std::vector<Node *> &inputs,
const std::vector<Node *> &outputs, const int otype) { const std::vector<Node *> &outputs, const int otype) {
auto to = VarType2PopStr(otype); auto to = VarType2PopartStr(static_cast<VarType::Type>(otype));
return CreateBaseOp(graph, node, "popart_cast", inputs, outputs, return CreateBaseOp(graph, node, "popart_cast", inputs, outputs,
{{"to", to}}); {{"to", to}});
} }
......
...@@ -23,12 +23,14 @@ namespace { ...@@ -23,12 +23,14 @@ namespace {
Node *fill_constant_handler(Graph *graph, Node *node) { Node *fill_constant_handler(Graph *graph, Node *node) {
auto *op = node->Op(); auto *op = node->Op();
if (!op->Input("ShapeTensor").empty()) { auto op_inputs = op->Inputs();
if (op_inputs.find("ShapeTensor") != op_inputs.end() &&
!op->Input("ShapeTensor").empty()) {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Unimplemented("op fill_constant with ShapeTensor")); platform::errors::Unimplemented("op fill_constant with ShapeTensor"));
} }
auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype")); auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
auto dtype = VarType2OnnxDtype(dtype_); auto dtype = VarType2OnnxDType(static_cast<VarType::Type>(dtype_));
auto dims = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape")); auto dims = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
auto value_ = BOOST_GET_CONST(float, op->GetAttr("value")); auto value_ = BOOST_GET_CONST(float, op->GetAttr("value"));
size_t size = 1; size_t size = 1;
...@@ -37,19 +39,20 @@ Node *fill_constant_handler(Graph *graph, Node *node) { ...@@ -37,19 +39,20 @@ Node *fill_constant_handler(Graph *graph, Node *node) {
} }
Attribute value; Attribute value;
switch (dtype_) { switch (dtype_) {
case framework::proto::VarType::FP32: case VarType::FP16:
case VarType::FP32:
value = std::vector<float>(size, value_); value = std::vector<float>(size, value_);
break; break;
case framework::proto::VarType::FP64: case VarType::FP64:
value = std::vector<double>(size, value_); value = std::vector<double>(size, value_);
break; break;
case framework::proto::VarType::INT32: case VarType::INT32:
value = std::vector<int>(size, value_); value = std::vector<int>(size, value_);
break; break;
case framework::proto::VarType::INT64: case VarType::INT64:
value = std::vector<int64_t>(size, value_); value = std::vector<int64_t>(size, value_);
break; break;
case framework::proto::VarType::BOOL: case VarType::BOOL:
value = std::vector<bool>(size, value_); value = std::vector<bool>(size, value_);
break; break;
default: default:
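The value attribute handed to the generated constant is a flat vector holding size copies of the scalar, where size is the product of the shape dims; shape {2, 3} with value 1.5, for example, becomes six copies of 1.5. A small sketch of just that expansion (the helper name here is illustrative, not from the patch):

#include <cstdint>
#include <vector>
// Illustrative helper: expand a fill_constant scalar over its shape.
std::vector<float> expand_fill_value(const std::vector<int64_t>& dims,
                                     float value) {
  size_t size = 1;
  for (auto d : dims) size *= static_cast<size_t>(d);  // {2, 3} -> 6
  return std::vector<float>(size, value);
}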
...@@ -66,7 +69,7 @@ Node *gaussian_random_handler(Graph *graph, Node *node) { ...@@ -66,7 +69,7 @@ Node *gaussian_random_handler(Graph *graph, Node *node) {
auto *op = node->Op(); auto *op = node->Op();
auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape")); auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype")); auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
auto dtype = VarType2OnnxDtype(dtype_); auto dtype = VarType2OnnxDType(static_cast<VarType::Type>(dtype_));
auto mean = BOOST_GET_CONST(float, op->GetAttr("mean")); auto mean = BOOST_GET_CONST(float, op->GetAttr("mean"));
auto scale = BOOST_GET_CONST(float, op->GetAttr("std")); auto scale = BOOST_GET_CONST(float, op->GetAttr("std"));
// seed not work // seed not work
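gaussian_random reads shape, dtype, mean and std from the op attributes; as the comment above notes, the seed attribute has no effect here. For reference only, the distribution the op asks the backend to sample is the plain normal distribution:

#include <random>
#include <vector>
// Reference-only: sample n values from N(mean, std^2), as the op requests.
std::vector<float> sample_gaussian(size_t n, float mean, float stddev) {
  std::mt19937 gen{std::random_device{}()};
  std::normal_distribution<float> dist(mean, stddev);
  std::vector<float> out(n);
  for (auto& v : out) v = dist(gen);
  return out;
}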
...@@ -86,7 +89,7 @@ Node *uniform_random_handler(Graph *graph, Node *node) { ...@@ -86,7 +89,7 @@ Node *uniform_random_handler(Graph *graph, Node *node) {
auto *op = node->Op(); auto *op = node->Op();
auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape")); auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype")); auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
auto dtype = VarType2OnnxDtype(dtype_); auto dtype = VarType2OnnxDType(static_cast<VarType::Type>(dtype_));
auto high = BOOST_GET_CONST(float, op->GetAttr("max")); auto high = BOOST_GET_CONST(float, op->GetAttr("max"));
auto low = BOOST_GET_CONST(float, op->GetAttr("min")); auto low = BOOST_GET_CONST(float, op->GetAttr("min"));
// seed not work // seed not work
...@@ -172,9 +175,21 @@ Node *squeeze_handler(Graph *graph, Node *node) { ...@@ -172,9 +175,21 @@ Node *squeeze_handler(Graph *graph, Node *node) {
Node *cast_handler(Graph *graph, Node *node) { Node *cast_handler(Graph *graph, Node *node) {
auto *op = node->Op(); auto *op = node->Op();
auto otype = BOOST_GET_CONST(int, op->GetAttr("out_dtype")); auto otype = BOOST_GET_CONST(int, op->GetAttr("out_dtype"));
auto new_node_cast = auto new_node = CreateCast(graph, node, node->inputs, node->outputs, otype);
CreateCast(graph, node, node->inputs, node->outputs, otype); // Cast op created in mixed-precision has no pipeline attrs
return new_node_cast; auto &prev_nodes = node->inputs.front()->inputs;
if (!prev_nodes.empty()) {
auto *prev_op = prev_nodes.front()->Op();
if (!new_node->Op()->HasAttr(sIpuIndexAttr) &&
prev_op->HasAttr(sIpuIndexAttr)) {
CopyOpAttr(sIpuIndexAttr, prev_op, new_node->Op());
}
if (!new_node->Op()->HasAttr(sIpuStageAttr) &&
prev_op->HasAttr(sIpuStageAttr)) {
CopyOpAttr(sIpuStageAttr, prev_op, new_node->Op());
}
}
return new_node;
} }
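The new branch keeps an inserted Cast on the same pipeline placement as its producer by copying sIpuIndexAttr and sIpuStageAttr from the input's defining op. A plausible sketch of CopyOpAttr, consistent with the declaration shown earlier but not taken from its actual definition in this patch:

// Plausible sketch only -- the real CopyOpAttr lives in ipu_utils.cc.
void CopyOpAttr(const std::string& attr_name, OpDesc* op, OpDesc* new_op,
                bool override) {
  if (new_op->HasAttr(attr_name) && !override) {
    return;  // keep the attribute already set on the new op
  }
  if (op->HasAttr(attr_name)) {
    new_op->SetAttr(attr_name, op->GetAttr(attr_name));
  }
}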
Node *lookup_table_op_handler(Graph *graph, Node *node, Node *lookup_table_op_handler(Graph *graph, Node *node,
...@@ -192,7 +207,7 @@ Node *lookup_table_op_handler(Graph *graph, Node *node, ...@@ -192,7 +207,7 @@ Node *lookup_table_op_handler(Graph *graph, Node *node,
auto concat_const = auto concat_const =
CreateConst(graph, node, {}, {}, {{"value", const_value_}, CreateConst(graph, node, {}, {}, {{"value", const_value_},
{"dims", const_shape_}, {"dims", const_shape_},
{"dtype", GetOutputVarDtype(node)}}); {"dtype", GetOutputVarDType(node)}});
auto axes = auto axes =
CreateConst(graph, node, {}, {}, {{"value", std::vector<int64_t>{0}}, CreateConst(graph, node, {}, {}, {{"value", std::vector<int64_t>{0}},
{"dims", std::vector<int64_t>{1}}, {"dims", std::vector<int64_t>{1}},
...@@ -397,7 +412,7 @@ Node *expand_handler(Graph *graph, Node *node) { ...@@ -397,7 +412,7 @@ Node *expand_handler(Graph *graph, Node *node) {
// cast to int64 // cast to int64
expand_times = expand_times =
CreateCast(graph, node, {GetInputVarNode("ExpandTimes", node)}, {}, CreateCast(graph, node, {GetInputVarNode("ExpandTimes", node)}, {},
framework::proto::VarType::INT64); VarType::INT64);
} else { } else {
auto expand_times_i32 = auto expand_times_i32 =
BOOST_GET_CONST(std::vector<int>, op->GetAttr("expand_times")); BOOST_GET_CONST(std::vector<int>, op->GetAttr("expand_times"));
...@@ -423,27 +438,28 @@ Node *assign_handler(Graph *graph, Node *node) { ...@@ -423,27 +438,28 @@ Node *assign_handler(Graph *graph, Node *node) {
Node *assign_value_handler(Graph *graph, Node *node) { Node *assign_value_handler(Graph *graph, Node *node) {
auto *op = node->Op(); auto *op = node->Op();
auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype")); auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
auto dtype = VarType2OnnxDtype(dtype_); auto dtype = VarType2OnnxDType(static_cast<VarType::Type>(dtype_));
auto dims_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("shape")); auto dims_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("shape"));
std::vector<int64_t> dims(dims_.begin(), dims_.end()); std::vector<int64_t> dims(dims_.begin(), dims_.end());
Attribute values; Attribute values;
std::string value_name; std::string value_name;
switch (dtype_) { switch (dtype_) {
case framework::proto::VarType::BOOL: { case VarType::BOOL: {
value_name = "bool_values"; value_name = "bool_values";
auto vec_int = BOOST_GET_CONST(std::vector<int>, op->GetAttr(value_name)); auto vec_int = BOOST_GET_CONST(std::vector<int>, op->GetAttr(value_name));
std::vector<bool> vec_bool(vec_int.begin(), vec_int.end()); std::vector<bool> vec_bool(vec_int.begin(), vec_int.end());
values = vec_bool; values = vec_bool;
} break; } break;
case framework::proto::VarType::INT32: case VarType::INT32:
value_name = "int32_values"; value_name = "int32_values";
values = BOOST_GET_CONST(std::vector<int>, op->GetAttr(value_name)); values = BOOST_GET_CONST(std::vector<int>, op->GetAttr(value_name));
break; break;
case framework::proto::VarType::FP32: case VarType::FP16:
case VarType::FP32:
value_name = "fp32_values"; value_name = "fp32_values";
values = BOOST_GET_CONST(std::vector<float>, op->GetAttr(value_name)); values = BOOST_GET_CONST(std::vector<float>, op->GetAttr(value_name));
break; break;
case framework::proto::VarType::INT64: case VarType::INT64:
value_name = "int64_values"; value_name = "int64_values";
values = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr(value_name)); values = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr(value_name));
break; break;
...@@ -463,39 +479,40 @@ Node *fill_any_like_handler(Graph *graph, Node *node) { ...@@ -463,39 +479,40 @@ Node *fill_any_like_handler(Graph *graph, Node *node) {
auto *op = node->Op(); auto *op = node->Op();
auto value = BOOST_GET_CONST(float, op->GetAttr("value")); auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
auto x_shape = GetInputVarNode("X", node)->Var()->GetShape(); auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
auto dtype = BOOST_GET_CONST(int, op->GetAttr("dtype")); auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
auto x_dtype = static_cast<framework::proto::VarType::Type>(dtype); auto dtype = static_cast<VarType::Type>(dtype_);
size_t size = 1; size_t size = 1;
for (auto &dim : x_shape) { for (auto &dim : x_shape) {
size *= dim; size *= dim;
} }
Attribute out_value; Attribute out_value;
switch (x_dtype) { switch (dtype) {
case framework::proto::VarType::FP32: case VarType::FP16:
case VarType::FP32:
out_value = std::vector<float>(size, value); out_value = std::vector<float>(size, value);
break; break;
case framework::proto::VarType::FP64: case VarType::FP64:
out_value = std::vector<double>(size, value); out_value = std::vector<double>(size, value);
break; break;
case framework::proto::VarType::INT32: case VarType::INT32:
out_value = std::vector<int>(size, value); out_value = std::vector<int>(size, value);
break; break;
case framework::proto::VarType::INT64: case VarType::INT64:
out_value = std::vector<int64_t>(size, value); out_value = std::vector<int64_t>(size, value);
break; break;
case framework::proto::VarType::BOOL: case VarType::BOOL:
out_value = std::vector<int64_t>(size, value); out_value = std::vector<int64_t>(size, value);
break; break;
default: default:
PADDLE_THROW( PADDLE_THROW(
platform::errors::Unimplemented("fill_any_like dtype: %d", x_dtype)); platform::errors::Unimplemented("fill_any_like dtype: %d", dtype));
} }
return CreateConst(graph, node, node->inputs, node->outputs, return CreateConst(graph, node, node->inputs, node->outputs,
AttributeMap{ AttributeMap{
{"value", out_value}, {"value", out_value},
{"dims", x_shape}, {"dims", x_shape},
{"dtype", VarType2OnnxDtype(dtype)}, {"dtype", VarType2OnnxDType(dtype)},
}); });
} }
...@@ -538,8 +555,7 @@ Node *one_hot_v2_handler(Graph *graph, Node *node) { ...@@ -538,8 +555,7 @@ Node *one_hot_v2_handler(Graph *graph, Node *node) {
{"dims", std::vector<int64_t>{1}}, {"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT32}}); {"dtype", ONNXDataType::INT32}});
Node *value_tensor = nullptr; Node *value_tensor = nullptr;
if (GetOutputVarNode("Out", node)->Var()->GetDataType() == if (GetOutputVarNode("Out", node)->Var()->GetDataType() == VarType::FP16) {
framework::proto::VarType::FP16) {
value_tensor = value_tensor =
CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{0, 1}}, CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{0, 1}},
{"dims", std::vector<int64_t>{2}}, {"dims", std::vector<int64_t>{2}},
......