Unverified commit 21b901cb authored by czr-gc, committed by GitHub

[IPU]: add model_runtime backend support in IPU (#47363)

* feat(ipu): add model_runtime backend support in IPU.

* fix(ipu_executor): fix error message format.

* fix(ipu_executor): fix format.

* fix(ipu_executor): fix format again.

* fix(ipu_executor): fix format again.

* fix(ipu_executor): fix format again.
Parent 981d1a10
......@@ -111,6 +111,16 @@ struct CastDataType {
out_begin,
CastDataTypeFunctor<InType, OutType>());
context->Wait();
#endif
#if defined(PADDLE_WITH_IPU)
} else if (platform::is_ipu_place(in_.place())) {
platform::Transform<phi::CPUContext> trans;
auto* context = static_cast<const phi::CPUContext*>(ctx_);
trans(*context,
in_begin,
in_end,
out_begin,
CastDataTypeFunctor<InType, OutType>());
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
......
......@@ -18,7 +18,9 @@ namespace paddle {
namespace framework {
namespace ir {
// msvc15 doesn't support constexpr correctly.
#if !defined(_WIN32)
// A static constexpr member is implicitly inline since C++17 and may cause
// multiple definitions.
#if !defined(_WIN32) && (__cplusplus < 201703L)
constexpr char Node::kControlDepVarName[];
#else
const char Node::kControlDepVarName[] = "__control_var";
......
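Note on the guard change above: before C++17 an odr-used static constexpr data member still needs an out-of-line definition, while from C++17 it is implicitly inline and the extra definition becomes redundant (the patch's stated reason for skipping it). A minimal standalone illustration, our own sketch rather than Paddle code:

struct Config {
  // Implicitly inline from C++17 onward.
  static constexpr char kName[] = "__control_var";
};
#if __cplusplus < 201703L
// Pre-C++17: required whenever kName is odr-used (e.g. bound to a reference).
constexpr char Config::kName[];
#endif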
......@@ -64,7 +64,10 @@ class Node {
enum class Type { kOperation, kVariable };
enum class Dep { kSame = 0, kBefore = 1, kAfter = 2, kNoDep = 3 };
#if !defined(_WIN32) // msvc not support constexpr correctly.
// msvc does not support constexpr correctly.
// A static constexpr member is implicitly inline since C++17 and may cause
// multiple definitions.
#if !defined(_WIN32) && (__cplusplus < 201703L)
static constexpr char kControlDepVarName[] = "__control_var";
#else
static const char kControlDepVarName[];
......
......@@ -23,7 +23,7 @@ if(WITH_IPU)
endforeach()
endforeach()
set(IPU_BACKEND_SRC "ipu_strategy.cc" "ipu_executor.cc" "ipu_compiler.cc"
set(IPU_BACKEND_SRC "ipu_strategy.cc" "ipu_compiler.cc" "ipu_executor.cc"
"ipu_backend.cc" "ipu_utils.cc")
set(IPU_INFO_SRC "ipu_info.cc" "ipu_device.cc")
......@@ -34,7 +34,20 @@ if(WITH_IPU)
cc_library(
ipu_backend
SRCS ${IPU_BACKEND_SRC}
DEPS popart-only graph graph_helper popdist popart_canonicalization)
DEPS popart-only
graph
graph_helper
popdist
popef
model_runtime
popart_canonicalization)
# The magic here: model_runtime requires C++17 and uses std::optional, while
# popart uses C++14 and nonstd::optional. A manual macro setting is required
# here to resolve the symbol conflict.
target_compile_features(ipu_backend PRIVATE cxx_std_17)
target_compile_definitions(ipu_backend
PRIVATE optional_CONFIG_SELECT_OPTIONAL=1)
cc_library(
ipu_info
SRCS ${IPU_INFO_SRC}
......
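A note on the macro above: optional_CONFIG_SELECT_OPTIONAL is optional-lite's implementation selector. Assuming optional-lite's documented convention (0 = auto-detect, 1 = bundled nonstd implementation, 2 = alias of std::optional), setting it to 1 pins nonstd::optional to its own implementation even in a C++17 build, so the C++14 popart objects and the C++17 ipu_backend objects agree on the type. A hedged sketch:

// Assumed optional-lite selector values: 0 = auto-detect by language level,
// 1 = always use the bundled nonstd implementation, 2 = alias std::optional.
#define optional_CONFIG_SELECT_OPTIONAL 1
#include <nonstd/optional.hpp>

nonstd::optional<int> maybe_value;  // same underlying type in C++14 and C++17 TUs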
......@@ -76,7 +76,11 @@ void IpuBackend::Run(const std::vector<const phi::DenseTensor*>& inputs,
const std::vector<phi::DenseTensor*>& outputs,
const framework::ExecutionContext& ctx) {
timer_->Start();
if (ipu_strategy_->enable_model_runtime_executor) {
executor_->RunPopef(inputs, outputs, ctx);
} else {
executor_->Run(inputs, outputs, ctx);
}
timer_->Pause();
VLOG(10) << "[IPU Run]: " << timer_->ElapsedMS() << " (ms)";
}
......
......@@ -14,9 +14,11 @@ limitations under the License. */
#include "paddle/fluid/platform/device/ipu/ipu_executor.h"
#include <chrono>
#include <popart/devicemanager.hpp>
#include <popdist/popdist_poplar.hpp>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device/ipu/ipu_compiler.h"
#include "paddle/fluid/platform/device/ipu/ipu_names.h"
......@@ -28,6 +30,13 @@ namespace ipu {
namespace {
model_runtime::AnchorCallbackPredicate PredFilterMain(
const model_runtime::Session *session) {
// Create predicates for binding Anchors from Main programs only
return model_runtime::predicate_factory::predProgramFlowMain(
session->model()->metadata.programFlow());
}
// Get paddle prefix and popart postfix of weight states
// Format: {popart_postfix, paddle_prefix}
std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
......@@ -156,6 +165,10 @@ void Executor::Prepare(const std::string &proto) {
VLOG(10) << "Setting random seed to: " << ipu_strategy_->random_seed;
session_->setRandomSeed(ipu_strategy_->random_seed);
}
enable_model_runtime_executor_ = ipu_strategy_->enable_model_runtime_executor;
if (enable_model_runtime_executor_) {
PreparePopefSession();
}
}
void Executor::Run(const std::vector<const Tensor *> &inputs,
......@@ -234,6 +247,208 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
VLOG(10) << "Running...done";
}
void Executor::PreparePopefSession() {
VLOG(10) << "enter Executor::PreparePopefSession";
if (popef_session_) {
VLOG(10) << "popef: previous popef model is not released, reset resources.";
ResetPopef();
}
auto popef_model = PopartSessionToPopefModel(session_.get());
auto num_buffers = ipu_strategy_->num_buffers;
// convert timeout_ms to timeout_ns
const std::chrono::nanoseconds timeout_ns(
int64_t(ipu_strategy_->timeout_ms * 1000000));
// prepare popef session
model_runtime::SessionConfig config;
config.policy = model_runtime::LaunchPolicy::Immediate;
popef_session_ =
std::make_unique<model_runtime::Session>(popef_model, config);
// prepare queue_manager
auto timeout_cb = [this](model_runtime::InputRingBuffer *buffer) {
VLOG(10) << "ModelRuntmie timeout callback is called.";
std::unique_lock lock(this->queue_mutex_);
if (buffer->readAvailable()) {
return;
}
this->queue_manager_->flushAll();
};
queue_manager_ =
popef_session_->createQueueManager(num_buffers,
timeout_cb,
timeout_ns,
PredFilterMain(popef_session_.get()),
PredFilterMain(popef_session_.get()));
// prepare program
popef_session_->runLoadPrograms();
main_program_ = std::thread([&]() {
while (!stop_.load()) {
VLOG(13) << "popef: Run main program";
popef_session_->runMainPrograms();
}
});
// Detach device from popart session
Detach();
}
void Executor::RunPopef(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs,
const framework::ExecutionContext &ctx) {
VLOG(10) << "enter Executor::RunPopef";
auto input_names = ctx.InputNames("FeedList");
auto output_names = ctx.OutputNames("FetchList");
int batch_size = 0;
bool auto_batch = (ipu_strategy_->timeout_ms != 0);
auto tensor_check = [&](const Tensor *tensor,
const popef::TensorInfo &info,
int *batch_size,
Tensor *cast_tensor) {
// check dtype
auto popef_phi_dtype = PopefDtype2PhiDtype(info.dataType());
bool casted = false;
if (popef_phi_dtype != tensor->dtype()) {
// popart may do implicit conversions (e.g. int64 -> int32), so a cast is
// needed in some cases.
VLOG(10) << "Cast paddle input type " << tensor->dtype() << " to "
<< popef_phi_dtype;
framework::TransDataType(
*tensor, PopefDType2VarType(info.dataType()), cast_tensor);
casted = true;
}
// check size
auto popef_input_shape = info.shape();
if (popef_input_shape.size() != tensor->dims().size()) {
PADDLE_THROW(
errors::Fatal("Incompatible tensor rank between paddle and popef."));
}
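// dim 0 is the batch dimension and is validated separately below, so the
// shape check starts from dim 1.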
for (int i = 1; i < popef_input_shape.size(); ++i) {
PADDLE_ENFORCE_EQ(
popef_input_shape[i],
tensor->dims().at(i),
errors::InvalidArgument("Invalid tensor size at dim %s. "
"popef expecting %s but received %s ",
i,
popef_input_shape[i],
tensor->dims().at(i)));
}
// check batch_size
if (!auto_batch) {
// disable auto batching
PADDLE_ENFORCE_EQ(
popef_input_shape[0],
tensor->dims().at(0),
errors::InvalidArgument(
"Batch size doesn't equal between paddle and popef."));
} else {
// enable auto batching
bool is_single_batch = ipu_strategy_->micro_batch_size == 1;
if (*batch_size == 0) {
// retrieve batch_size
*batch_size = is_single_batch ? 1 : tensor->dims().at(0);
} else if (!is_single_batch) {
// inputs/outputs must carry batch info when auto batching is enabled.
PADDLE_ENFORCE_EQ(*batch_size,
tensor->dims().at(0),
errors::InvalidArgument(
"batch size should be equal for each tensor"));
}
}
return casted;
};
const auto &session_inputs = popef_session_->getUserInputAnchors();
std::vector<Tensor> cast_tensor(inputs.size());
const auto &session_outputs = popef_session_->getUserOutputAnchors();
// ModelRuntime::Queue is not thread safe.
std::unique_lock lock(queue_mutex_);
for (size_t i = 0; i < inputs.size(); ++i) {
const auto &popef_input_name =
compiler_resources_->tensors.at(input_names[i]);
auto &elem_queue = queue_manager_->inputQueue(popef_input_name);
const auto &info = elem_queue.tensorInfo();
VLOG(10) << "popef: handle popef input: " << popef_input_name
<< " mapped with paddle " << input_names[i];
bool casted = tensor_check(inputs[i], info, &batch_size, &(cast_tensor[i]));
const void *data = casted ? cast_tensor[i].data() : inputs[i]->data();
const auto size =
casted ? cast_tensor[i].memory_size() : inputs[i]->memory_size();
elem_queue.enqueue(data, size, [popef_input_name]() {
VLOG(10) << "popef: enqueued data for input: " << popef_input_name;
});
}
std::vector<std::future<void>> finish_indicators;
finish_indicators.reserve(session_outputs.size());
for (size_t i = 0; i < session_outputs.size(); ++i) {
const auto &popef_output_name =
compiler_resources_->tensors.at(output_names[i]);
auto &out_queue = queue_manager_->outputQueue(popef_output_name);
const auto &info = out_queue.tensorInfo();
VLOG(10) << "popef: handle popef output: " << popef_output_name
<< " mapped with paddle " << output_names[i];
auto popef_dtype = info.dataType();
auto paddle_dtype = PopefDType2VarType(popef_dtype);
auto output_shape = info.shape();
if (auto_batch) {
if (output_shape[0] == ipu_strategy_->micro_batch_size) {
output_shape[0] = batch_size;
} else {
// the output shape must carry batch info when auto batching is enabled
PADDLE_THROW(platform::errors::Unimplemented(
"Auto batch doesn't support the tensor with no batch info. "
"Expected batch size in output tensor: %d should equal to "
"micro batch size: %d. Please make sure batch size is set "
"correctly in both IPU program compiling and IpuStrategy.",
output_shape[0],
ipu_strategy_->micro_batch_size));
}
}
auto *tensor = outputs[i];
// resize the output so that the data pointer is valid.
tensor->Resize(phi::make_ddim(output_shape));
tensor->mutable_data(ctx.GetPlace(),
framework::TransToPhiDataType(paddle_dtype));
const auto size = tensor->memory_size();
auto promise = std::make_shared<std::promise<void>>();
finish_indicators.emplace_back(promise->get_future());
out_queue.enqueue(tensor->data(), size, [popef_output_name, promise]() {
VLOG(10) << "popef: received output: " << popef_output_name;
promise->set_value();
});
}
lock.unlock();
// Wait for the outputs synchronously. Asynchronous execution is not supported,
// since the Python API call is synchronous and the output data is copied outside.
for (const auto &indicator : finish_indicators) {
indicator.wait();
}
}
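To make the auto-batching rule in tensor_check concrete: with timeout_ms != 0, the first input's dim 0 fixes the dynamic batch unless micro_batch_size == 1, and every later tensor must agree. A standalone mirror of that inference, illustrative only and not part of the patch:

#include <cassert>
#include <vector>

// dim0s holds dim 0 of each fed tensor, in feed order.
int InferAutoBatch(const std::vector<int>& dim0s, int micro_batch_size) {
  int batch = 0;
  for (int d0 : dim0s) {
    if (micro_batch_size == 1) {
      batch = 1;  // single-batch programs ignore the fed dim 0
    } else if (batch == 0) {
      batch = d0;  // the first tensor sets the dynamic batch
    } else {
      assert(batch == d0 && "batch size must match across tensors");
    }
  }
  return batch;
}
// InferAutoBatch({6, 6}, 2) == 6; InferAutoBatch({6, 6}, 1) == 1.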
void Executor::WeightsToHost() {
if (ipu_strategy_->is_training && session_) {
WeightsToPaddle();
......@@ -316,6 +531,32 @@ void Executor::Reset() {
Detach();
session_.reset();
executor_resources_.reset();
if (enable_model_runtime_executor_) {
ResetPopef();
}
}
void Executor::ResetPopef() {
VLOG(10) << "Reset popef resources.";
stop_.store(true);
if (queue_manager_) {
queue_manager_->disconnectAll();
}
if (main_program_.joinable()) {
const auto future = std::async(std::launch::async,
[this]() { this->main_program_.join(); });
if (future.wait_for(std::chrono::seconds(10)) ==
std::future_status::timeout) {
popef_session_->stop();
VLOG(10) << "popef: failed to wait for main program. Force stop popef "
"session.";
}
}
popef_session_.reset();
// Reset stop_ back to false in case the executor is reused.
stop_.store(false);
queue_manager_ = nullptr;
}
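ResetPopef uses a bounded-join idiom worth isolating: std::thread::join has no deadline, so the join runs inside std::async and the caller waits on the returned future. A minimal standalone sketch with our own names; note the caveat the executor covers by calling popef_session_->stop() on timeout, since the future's destructor still blocks until the thread really joins:

#include <chrono>
#include <future>
#include <thread>

// Returns false if the thread did not finish within the deadline.
bool JoinWithTimeout(std::thread* t, std::chrono::seconds deadline) {
  auto joined = std::async(std::launch::async, [t] { t->join(); });
  return joined.wait_for(deadline) != std::future_status::timeout;
}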
void Executor::SetWeightsIO() {
......
......@@ -14,6 +14,10 @@ limitations under the License. */
#pragma once
#include <mutex>
#include <model_runtime/ModelRunner.hpp>
#include <model_runtime/Tensor.hpp>
#include <popart/dataflow.hpp>
#include <popart/half.hpp>
#include <popart/names.hpp>
......@@ -57,6 +61,11 @@ class Executor {
const std::vector<Tensor *> &outputs,
const framework::ExecutionContext &ctx);
// Run popef session
void RunPopef(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs,
const framework::ExecutionContext &ctx);
// Sync weights from popart to paddle
void WeightsToHost();
......@@ -88,13 +97,15 @@ class Executor {
void ConvertWeights(bool);
void WeightsFromPaddle();
void WeightsToPaddle();
void PreparePopefSession();
void ResetPopef();
private:
// Not own
const Scope *scope_ = nullptr;
const IpuStrategy *ipu_strategy_ = nullptr;
CompilerResources *compiler_resources_ = nullptr;
bool compile_only_ = false;
model_runtime::QueueManager *queue_manager_ = nullptr;
// DeviceInfo for the popart session
std::shared_ptr<popart::DeviceInfo> device_;
......@@ -102,6 +113,20 @@ class Executor {
std::unique_ptr<popart::Session> session_;
// An ExecutorResources corresponds to a graph
std::unique_ptr<ExecutorResources> executor_resources_;
// mutex to lock session run.
mutable std::mutex mutex_;
// popef session
std::unique_ptr<model_runtime::Session> popef_session_;
// popef execution thread
std::thread main_program_;
// mutex to lock popef queue
std::mutex queue_mutex_;
bool compile_only_ = false;
// Indicates whether popart or popef is used. Do not read ipu_strategy_,
// which may be destructed before the executor and cause undefined behavior.
bool enable_model_runtime_executor_ = false;
std::atomic_bool stop_ = {false};
};
} // namespace ipu
......
......@@ -87,24 +87,26 @@ IpuStrategy::IpuStrategy() {
ADD_BOOL_OPTION(use_no_bias_optimizer);
ADD_BOOL_OPTION(enable_distribution);
ADD_BOOL_OPTION(scaled_optimizer_state);
ADD_BOOL_OPTION(is_dynamic);
ADD_BOOL_OPTION(enable_model_runtime_executor);
ADD_UINT64_OPTION(num_ipus);
ADD_UINT64_OPTION(batches_per_step);
ADD_UINT64_OPTION(micro_batch_size);
ADD_UINT64_OPTION(random_seed);
ADD_UINT64_OPTION(tiles_per_ipu);
ADD_UINT64_OPTION(num_buffers);
ADD_DOUBLE_OPTION(available_memory_proportion);
ADD_DOUBLE_OPTION(loss_scaling);
ADD_DOUBLE_OPTION(max_weight_norm);
ADD_DOUBLE_OPTION(timeout_ms);
// dy2static support
ADD_DOUBLE_OPTION(lr);
ADD_STRING_OPTION(accl1_type);
ADD_STRING_OPTION(accl2_type);
ADD_STRING_OPTION(accl3_type);
ADD_STRING_OPTION(onnx_dump_path);
ADD_STRING_OPTION(weight_decay_mode);
// dy2static support
ADD_DOUBLE_OPTION(lr);
ADD_BOOL_OPTION(is_dynamic);
#undef ADD_STRING_OPTION
#undef ADD_DOUBLE_OPTION
#undef ADD_UINT64_OPTION
......
......@@ -118,6 +118,16 @@ class IpuStrategy {
// whether in dynamic mode
bool is_dynamic = false;
// Use the popart executor by default; enable the model_runtime executor if
// set to true.
bool enable_model_runtime_executor = false;
// number of buffers for the model_runtime executor
int num_buffers = 10;
// timeout in milliseconds for model_runtime
double timeout_ms = 0.0;
public:
void AddBoolOption(const std::string &option, bool value);
void AddUint64Option(const std::string &option, std::uint64_t value);
......
......@@ -84,6 +84,32 @@ const popart::DataType PhiDType2PopartDType(const phi::DataType type) {
}
}
const phi::DataType PopefDtype2PhiDtype(const popef::DataType type) {
switch (type) {
case popef::DataType::U8:
return phi::DataType::UINT8;
case popef::DataType::S8:
return phi::DataType::INT8;
case popef::DataType::S16:
return phi::DataType::INT16;
case popef::DataType::S32:
return phi::DataType::INT32;
case popef::DataType::S64:
return phi::DataType::INT64;
case popef::DataType::BOOL:
return phi::DataType::BOOL;
case popef::DataType::F64:
return phi::DataType::FLOAT64;
case popef::DataType::F32:
return phi::DataType::FLOAT32;
case popef::DataType::F16:
return phi::DataType::FLOAT16;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported phi::DataType when converting to popef data type."));
}
}
const VarType::Type PopartDType2VarType(const popart::DataType type) {
switch (type) {
case popart::DataType::UINT8:
......@@ -116,6 +142,32 @@ const VarType::Type PopartDType2VarType(const popart::DataType type) {
}
}
const VarType::Type PopefDType2VarType(const popef::DataType type) {
switch (type) {
case popef::DataType::U8:
return VarType::UINT8;
case popef::DataType::S8:
return VarType::INT8;
case popef::DataType::S16:
return VarType::INT16;
case popef::DataType::S32:
return VarType::INT32;
case popef::DataType::S64:
return VarType::INT64;
case popef::DataType::BOOL:
return VarType::BOOL;
case popef::DataType::F64:
return VarType::FP64;
case popef::DataType::F32:
return VarType::FP32;
case popef::DataType::F16:
return VarType::FP16;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported popart::DataType when converting to var type."));
}
}
const popart::DataType OnnxDType2PopartType(const ONNXDataType type) {
switch (type) {
case ONNXDataType::BOOL:
......@@ -223,6 +275,16 @@ const int RequestIpus(const int num_ipus) {
return std::pow(2, ceil(log2(num_ipus)));
}
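A quick worked check of the rounding above, as an illustrative standalone mirror (assumes num_ipus >= 1): log2(3) ≈ 1.585, ceil gives 2, so a request for three IPUs becomes four, while powers of two pass through unchanged.

#include <cmath>

int NextPow2(int n) {  // n >= 1
  return static_cast<int>(std::pow(2, std::ceil(std::log2(n))));
}
// NextPow2(1) == 1, NextPow2(3) == 4, NextPow2(5) == 8, NextPow2(8) == 8.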
std::shared_ptr<popef::Model> PopartSessionToPopefModel(
popart::Session* session) {
VLOG(10) << "Converting popart session to popef model";
auto temp_stream = std::make_shared<std::stringstream>();
session->compileAndExport(*temp_stream);
auto reader = std::make_shared<popef::Reader>();
reader->parseStream(temp_stream);
return popef::ModelBuilder(reader).createModel();
}
} // namespace ipu
} // namespace platform
} // namespace paddle
......@@ -15,9 +15,14 @@ limitations under the License. */
#pragma once
#include <popart/ndarraywrapper.hpp>
#include <popart/session.hpp>
#include <popart/tensordata.hpp>
#include <popart/tensorinfo.hpp>
#include <popart/vendored/any.hpp>
#include <popart/vendored/optional.hpp>
#include <popef/Model.hpp>
#include <popef/Reader.hpp>
#include <popef/Types.hpp>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -85,6 +90,10 @@ enum ONNXDataType : int {
BFLOAT16 = 16
};
// TODO(czr): remove const qualifier on return value.
const VarType::Type PopefDType2VarType(const popef::DataType type);
const phi::DataType PopefDtype2PhiDtype(const popef::DataType type);
// VarType::Type to popart::DataType
const popart::DataType VarType2PopartDType(const VarType::Type type);
// phi::DataType to popart::DataType
......@@ -102,6 +111,10 @@ const bool GetBoolEnv(const std::string& str);
// The requested number of IPUs must be pow(2, n)
const int RequestIpus(const int num_ipus);
// Convert popart session to popef
std::shared_ptr<popef::Model> PopartSessionToPopefModel(
popart::Session* session);
} // namespace ipu
} // namespace platform
} // namespace paddle
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
import paddle
from op_test_ipu import IPUOpTest
class SimpleLayer(paddle.nn.Layer):
def __init__(self):
super(SimpleLayer, self).__init__()
self.conv = paddle.nn.Conv2D(
in_channels=3, out_channels=1, kernel_size=2, stride=1
)
def forward(self, x, target=None):
x = self.conv(x)
x = paddle.fluid.layers.flatten(x, axis=1)
if target is not None:
x = paddle.fluid.layers.softmax(x)
loss = paddle.fluid.layers.cross_entropy(x, target)
return x, loss
return x
class TestBase(IPUOpTest):
def setUp(self):
self.ipu_model = None
self.set_attrs()
if 'POPLAR_IPUMODEL' in os.environ:
self.ipu_model = os.environ['POPLAR_IPUMODEL']
del os.environ['POPLAR_IPUMODEL']
def set_attrs(self):
self.timeout = 0.0
self.batch_size = 8
def tearDown(self):
if getattr(self, 'ipu_model', None):
os.environ['POPLAR_IPUMODEL'] = self.ipu_model
paddle.framework.core.IpuBackend.get_instance().reset()
def generate_feed(self):
return {
"X": np.random.rand(8, 3, 10, 10).astype(np.float32),
"Y": np.random.randint(0, 10, [8], dtype="int64"),
}
@IPUOpTest.static_graph
def build_model(self):
x = paddle.static.data(
name='X', shape=[self.batch_size, 3, 10, 10], dtype='float32'
)
label = paddle.static.data(
name='Y', shape=[self.batch_size], dtype='int64'
)
model = SimpleLayer()
pred, loss = model(x, label)
self.feed_list = [x.name, label.name]
self.fetch_list = [pred.name, loss.name]
def reset_seeds(self):
np.random.seed(self.SEED)
paddle.seed(self.SEED)
self.main_prog.random_seed = self.SEED
self.startup_prog.random_seed = self.SEED
def _test(self, use_ipu=False):
self.reset_seeds()
place = paddle.IPUPlace() if use_ipu else paddle.CPUPlace()
executor = paddle.static.Executor(place)
executor.run(self.startup_prog)
if use_ipu:
paddle.set_device('ipu')
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(
num_ipus=1,
is_training=False,
micro_batch_size=self.batch_size,
enable_manual_shard=False,
)
ipu_strategy.set_options(
{
'enable_model_runtime_executor': True,
'timeout_ms': self.timeout,
}
)
program = paddle.static.IpuCompiledProgram(
self.main_prog, ipu_strategy=ipu_strategy
).compile(self.feed_list, self.fetch_list)
else:
program = self.main_prog
epochs = 10
preds = []
losses = []
for epoch in range(epochs):
feed = self.generate_feed()
dy_batch = feed["X"].shape[0]
if not use_ipu:
# pad inputs to the fixed batch size
pad_batch = self.batch_size - dy_batch
for k, v in feed.items():
pad_size = tuple(
(
(0, 0 if i != 0 else pad_batch)
for i in range(len(v.shape))
)
)
feed[k] = np.pad(v, pad_size, 'constant', constant_values=0)
pred, loss = executor.run(
program, feed=feed, fetch_list=self.fetch_list
)
if not use_ipu:
pred = pred[0:dy_batch]
loss = loss[0:dy_batch]
preds.append(pred)
losses.append(loss)
return np.concatenate(preds, axis=0), np.concatenate(losses, axis=0)
def test_infer(self):
self.build_model()
ipu_pred, ipu_loss = self._test(True)
cpu_pred, cpu_loss = self._test(False)
np.testing.assert_allclose(
ipu_pred.flatten(), cpu_pred.flatten(), rtol=1e-05, atol=1e-4
)
np.testing.assert_allclose(
ipu_loss.flatten(), cpu_loss.flatten(), rtol=1e-05, atol=1e-4
)
class TestAutoBatch(TestBase):
def set_attrs(self):
self.timeout = 0.01
# fixed batch
self.batch_size = 8
def generate_feed(self):
# generate dynamic batch
batch = np.random.randint(1, self.batch_size)
return {
"X": np.random.rand(batch, 3, 10, 10).astype(np.float32),
"Y": np.random.randint(0, 10, [batch], dtype="int64"),
}
if __name__ == "__main__":
unittest.main()