Commit c27f6782 authored by Megvii Engine Team

feat(src/opr): add C++ training API and related tests

GitOrigin-RevId: befb85fd43c02167d685d1e5e5657f8f5b94fb7f
Parent 6bb54099
/**
* \file src/opr/impl/training/dataview.cpp
*
* This file is part of MegBrain, a deep learning framework developed by Megvii.
*
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*/
#include "megbrain/opr/training/dataview.h"
#include "megbrain/exception.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/tensor.h"
#include <random>
namespace mgb {
DataLoader::DataLoader(
std::shared_ptr<IDataView> dataview, mgb::CompNode comp_node,
unsigned long batchsize, bool shuffle, bool drop_last)
: m_dataview(dataview),
m_comp_node(comp_node),
m_batchsize(batchsize),
m_shuffle(shuffle),
m_drop_last(drop_last),
m_idx(0) {
if (!m_comp_node.valid()) {
m_comp_node = CompNode::load("xpu0");
}
for (size_t i = 0; i < m_dataview->size(); i++) {
m_index_collection.push_back(i);
}
if (m_dataview->size() > 0) {
auto data_sample = m_dataview->get_item(0);
SmallVector<size_t> dshape;
dshape.push_back(static_cast<size_t>(batchsize));
for (size_t i = 0; i < data_sample.first->layout().ndim; i++) {
dshape.push_back(data_sample.first->shape()[i]);
}
m_data_shape = dshape;
SmallVector<size_t> lshape;
lshape.push_back(m_batchsize);
for (size_t i = 1; i < data_sample.second->layout().ndim; i++) {
lshape.push_back(data_sample.second->shape()[i]);
}
m_label_shape = lshape;
m_data_type = data_sample.first->dtype();
m_label_type = data_sample.second->dtype();
} else {
mgb_throw(AssertionError, "The dataset is empty.");
}
}
size_t DataLoader::size() {
return m_dataview->size() / m_batchsize;
}
DataPair DataLoader::next() {
if (m_idx == 0 && m_shuffle) {
std::shuffle(
m_index_collection.begin(), m_index_collection.end(),
std::default_random_engine());
}
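    // Wrap back to the start of the (possibly shuffled) index list once fewer
    // than one full batch of samples remains; the tail samples are dropped.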
    if (m_idx + m_batchsize > m_index_collection.size()) {
        m_idx = 0;
    }
auto data = std::make_shared<HostTensorND>(m_comp_node, m_data_shape, m_data_type);
auto label =
std::make_shared<HostTensorND>(m_comp_node, m_label_shape, m_label_type);
size_t data_bytes = m_dataview->get_item(m_index_collection.at(m_idx))
.first->layout()
.access_bytes();
size_t label_bytes = m_dataview->get_item(m_index_collection.at(m_idx))
.second->layout()
.access_bytes();
auto data_ptr = data->raw_ptr();
auto label_ptr = label->raw_ptr();
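    // Copy one sample at a time into the contiguous batch tensors; every
    // sample is assumed to share the shape and dtype probed in the constructor.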
for (unsigned int i = 0; i < m_batchsize; i++) {
auto item = m_dataview->get_item(m_index_collection.at(m_idx));
auto pre_data = item.first;
auto pre_label = item.second;
auto pre_data_ptr = pre_data->raw_ptr();
auto pre_label_ptr = pre_label->raw_ptr();
        memcpy(data_ptr + data_bytes * i, pre_data_ptr, data_bytes);
        memcpy(label_ptr + label_bytes * i, pre_label_ptr, label_bytes);
m_idx++;
}
return {data, label};
}
} // namespace mgb
/**
* \file src/opr/impl/training/loss.cpp
*
* This file is part of MegBrain, a deep learning framework developed by Megvii.
*
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*/
#include "megbrain/opr/training/loss.h"
#include "megbrain/exception.h"
#include "megbrain/opr/indexing.h"
namespace mgb {
namespace loss {
CrossEntropyLoss::CrossEntropyLoss(
bool with_logits, float label_smooth, ReduceMode reduce_mode, int axis)
: m_with_logits(with_logits),
m_label_smooth(label_smooth),
m_reduce_mode(reduce_mode),
m_axis(axis) {}
SymbolVar CrossEntropyLoss::operator()(
mgb::SymbolVar symbol_pred, mgb::SymbolVar symbol_label) {
mgb_assert(
            symbol_pred.shape().ndim >= symbol_label.shape().ndim,
            "The label must not have more dimensions than the pred.");
for (size_t i = 0; i < symbol_label.shape().ndim; i++) {
mgb_assert(
symbol_pred.shape()[i] == symbol_label.shape()[i] || (int)i == m_axis,
"Unmatched shape for pred and label.");
}
    mgb_assert(m_label_smooth >= .0f, "The label_smooth must be a non-negative value.");
SymbolVar symbol_loss;
SymbolVar symbol_middle;
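    // With logits the loss is the numerically stable log-sum-exp form:
    //   loss = max + log(sum(exp(pred - max))) - pred[label]
    // With label smoothing, pred[label] is blended with the mean of pred
    // (or of log(pred) when the inputs are probabilities) along `m_axis`.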
SymbolVar symbol_max = opr::reduce_ax_max(symbol_pred, m_axis);
SymbolVar symbol_primary_item =
opr::IndexingOneHot::make(symbol_pred, symbol_label, {m_axis});
if (m_with_logits) {
symbol_middle = opr::reduce_ax_sum(symbol_pred, m_axis) /
opr::GetVarShape::make(symbol_pred, {m_axis});
SymbolVar symbol_logits =
symbol_max + opr::log(opr::reduce_ax_sum(
opr::exp(symbol_pred - symbol_max), m_axis));
symbol_loss = symbol_logits;
    } else {
        // The inputs are already probabilities: the loss is -log(p[label]), so
        // the accumulated term starts from zero.
        symbol_loss = symbol_pred.make_scalar(.0f);
        symbol_middle = opr::reduce_ax_sum(opr::log(symbol_pred), m_axis) /
                        opr::GetVarShape::make(symbol_pred, {m_axis});
        symbol_primary_item = opr::log(symbol_primary_item);
    }
if (m_label_smooth > .0f) {
symbol_loss = symbol_loss - m_label_smooth * symbol_middle -
(1 - m_label_smooth) * symbol_primary_item;
} else {
symbol_loss = symbol_loss - symbol_primary_item;
}
if (m_reduce_mode == ReduceMode::MEAN) {
symbol_loss =
opr::reduce_sum(symbol_loss.flatten(), symbol_loss.make_scalar(1)) /
(float)(symbol_loss.shape().total_nr_elems());
} else if (m_reduce_mode == ReduceMode::SUM) {
symbol_loss =
opr::reduce_sum(symbol_loss.flatten(), symbol_loss.make_scalar(1));
}
return symbol_loss;
}
MSELoss::MSELoss(ReduceMode reduce_mode) : m_reduce_mode(reduce_mode) {}
mgb::SymbolVar MSELoss::operator()(
        mgb::SymbolVar symbol_pred, mgb::SymbolVar symbol_label) {
    SymbolVar loss = opr::pow(symbol_pred - symbol_label, symbol_pred.make_scalar(2));
    loss = opr::reduce_sum(loss.flatten(), loss.make_scalar(1));
    if (m_reduce_mode == ReduceMode::MEAN)
        loss = loss / (float)(symbol_pred.shape().total_nr_elems());
    return loss;
}
} // namespace loss
} // namespace mgb
/**
* \file src/opr/impl/training/optimizer.cpp
*
* This file is part of MegBrain, a deep learning framework developed by Megvii.
*
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*/
#include "megbrain/opr/training/optimizer.h"
#include "megbrain/exception.h"
#include "megbrain/opr/training/utils.h"
namespace mgb {
namespace optimizer {
SymbolVarArray Optimizer::make_multiple(
SymbolVarArray symbol_weights, SymbolVarArray symbol_grads,
std::shared_ptr<mgb::cg::ComputingGraph> graph) {
if (symbol_weights.size() != symbol_grads.size()) {
mgb_throw(AssertionError, "The count of weights differs with that of grads.");
}
SymbolVarArray r;
for (size_t i = 0; i < symbol_weights.size(); i++) {
r.push_back(make(symbol_weights[i], symbol_grads[i], graph));
}
return r;
}
SGD::SGD(float lr, float weight_decay, float momentum)
: m_lr(lr), m_weight_decay(weight_decay), m_momentum(momentum) {
if (m_lr <= 0) {
mgb_throw(AssertionError, "Invalid learning rate: negative value.");
}
if (m_weight_decay < 0) {
mgb_throw(AssertionError, "Invalid weight_decay value: negative value.");
}
if (m_momentum < 0) {
mgb_throw(AssertionError, "Invalid momentum value: negative value.");
}
}
SymbolVar SGD::make(
SymbolVar symbol_weight, SymbolVar symbol_grad,
std::shared_ptr<cg::ComputingGraph> graph) {
SymbolVar symbol_pre_grad;
auto pre_grad = TensorGen::zeros<dtype::Float32>(
symbol_grad.shape(), symbol_grad.node()->comp_node());
m_pre_grads.push_back(pre_grad);
symbol_pre_grad = opr::SharedDeviceTensor::make(*graph, *pre_grad);
if (m_weight_decay != .0f) {
symbol_grad = symbol_grad + m_weight_decay * symbol_weight;
}
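    // opr::AddUpdate computes dest = alpha * dest + beta * delta, so with
    // momentum: pre_grad = momentum * pre_grad + grad and
    // weight = weight - lr * pre_grad; without momentum the raw gradient is used.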
if (m_momentum != .0f) {
symbol_pre_grad =
opr::AddUpdate::make(symbol_pre_grad, symbol_grad, {m_momentum, 1.0f});
return opr::AddUpdate::make(symbol_weight, -symbol_pre_grad, {1.f, m_lr});
} else {
return opr::AddUpdate::make(symbol_weight, -symbol_grad, {1.f, m_lr});
}
}
Adam::Adam(
float lr, float weight_decay, std::pair<float, float> betas, float eps,
bool amsgrad)
: m_lr(lr),
m_weight_decay(weight_decay),
m_betas(betas),
m_eps(eps),
m_amsgrad(amsgrad) {
    mgb_assert(m_lr > 0, "Invalid learning rate: non-positive value.");
mgb_assert(m_weight_decay >= 0, "Invalid weight_decay value: negative value.");
mgb_assert(
m_betas.first >= 0 && m_betas.second >= 0 && m_betas.first < 1 &&
m_betas.second < 1,
"Invalid betas value: negative value or larger than 1.");
}
SymbolVar Adam::make(
SymbolVar symbol_weight, SymbolVar symbol_grad,
std::shared_ptr<cg::ComputingGraph> graph) {
CompNode comp_node = symbol_grad.node()->comp_node();
DType dt = symbol_grad.dtype();
m_correction1 = TensorGen::ones<dtype::Float32>({1}, comp_node);
m_correction2 = TensorGen::ones<dtype::Float32>({1}, comp_node);
std::shared_ptr<DeviceTensorND> exp_avg =
std::make_shared<DeviceTensorND>(comp_node, symbol_grad.shape(), dt);
mgb::fill_zero_dev_tensor(*exp_avg);
std::shared_ptr<DeviceTensorND> exp_avg_sq =
std::make_shared<DeviceTensorND>(comp_node, symbol_grad.shape(), dt);
mgb::fill_zero_dev_tensor(*exp_avg_sq);
m_exp_avg.push_back(exp_avg);
m_exp_avg_sq.push_back(exp_avg_sq);
SymbolVar symbol_correction1 =
opr::SharedDeviceTensor::make(*graph, *m_correction1);
SymbolVar symbol_correction2 =
opr::SharedDeviceTensor::make(*graph, *m_correction2);
SymbolVar symbol_exp_avg = opr::SharedDeviceTensor::make(*graph, exp_avg);
SymbolVar symbol_exp_avg_sq = opr::SharedDeviceTensor::make(*graph, exp_avg_sq);
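    // correction1 / correction2 accumulate beta1^t and beta2^t across iterations;
    // the bias corrections (1 - beta^t) are applied when computing `delta` below.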
symbol_correction1 = opr::AddUpdate::make(
symbol_correction1, symbol_correction1, {m_betas.first, .0f});
symbol_correction2 = opr::AddUpdate::make(
symbol_correction2, symbol_correction2, {m_betas.second, .0f});
if (m_weight_decay != .0f) {
symbol_grad = symbol_grad + m_weight_decay * symbol_weight;
}
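    // Exponential moving averages of the gradient (first moment) and of the
    // squared gradient (second moment).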
symbol_exp_avg = opr::AddUpdate::make(
symbol_exp_avg, symbol_grad, {m_betas.first, 1.f - m_betas.first});
symbol_exp_avg_sq = opr::AddUpdate::make(
symbol_exp_avg_sq, symbol_grad * symbol_grad,
{m_betas.second, 1.f - m_betas.second});
SymbolVar delta;
if (m_amsgrad) {
std::shared_ptr<DeviceTensorND> max_exp_avg_sq =
std::make_shared<DeviceTensorND>(comp_node, symbol_grad.shape(), dt);
mgb::fill_zero_dev_tensor(*max_exp_avg_sq);
        m_max_exp_avg_sq.push_back(max_exp_avg_sq);
        SymbolVar symbol_max_exp_avg_sq =
                opr::SharedDeviceTensor::make(*graph, max_exp_avg_sq);
        symbol_max_exp_avg_sq = opr::AddUpdate::make(
                symbol_max_exp_avg_sq,
                opr::max(symbol_max_exp_avg_sq, symbol_exp_avg_sq), {.0f, 1.0f});
        delta = (symbol_exp_avg / (1.f - symbol_correction1)) /
                (opr::pow(
                         symbol_max_exp_avg_sq / (1.f - symbol_correction2),
                         symbol_exp_avg.make_scalar(0.5f)) +
                 m_eps);
} else {
delta = (symbol_exp_avg / (1.f - symbol_correction1)) /
(opr::pow(
symbol_exp_avg_sq / (1.f - symbol_correction2),
symbol_exp_avg.make_scalar(0.5f)) +
m_eps);
}
return opr::AddUpdate::make(symbol_weight, -delta, {1.0f, m_lr});
}
} // namespace optimizer
} // namespace mgb
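// A minimal sketch of one training step built from the optimizers above,
// assuming a `graph`, a scalar `symbol_loss` produced by one of the losses,
// and `symbol_weights` holding the trainable SymbolVars; these names, and the
// use of cg::grad to obtain gradients, are illustrative assumptions:
//
//     optimizer::SGD sgd(0.01f, .0f, 0.9f);
//     SymbolVarArray symbol_grads;
//     for (auto&& weight : symbol_weights)
//         symbol_grads.push_back(cg::grad(symbol_loss, weight));
//     SymbolVarArray symbol_updates =
//             sgd.make_multiple(symbol_weights, symbol_grads, graph);
//     cg::ComputingGraph::OutputSpec spec;
//     for (auto&& update : symbol_updates)
//         spec.push_back({update, {}});  // empty callback: only run the update
//     auto func = graph->compile(spec);
//     func->execute();  // one optimization step over the current inputs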
/**
* \file src/opr/include/training/dataview.h
*
* This file is part of MegBrain, a deep learning framework developed by Megvii.
*
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*/
#pragma once
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/tensor_manip.h"
#include <type_traits>
namespace mgb {
using DataPair = std::pair<
std::shared_ptr<mgb::HostTensorND>, std::shared_ptr<mgb::HostTensorND>>;
//! The interface of the dataset.
class IDataView {
public:
/*!
     * Get the (data, label) pair at the given index in the dataset.
*/
virtual DataPair get_item(int idx) = 0;
/*!
* The method to get the size of the dataset.
*/
virtual size_t size() = 0;
virtual ~IDataView() = default;
};
//! The data loader, corresponding to <DataLoader> in the Python API of
//! MegEngine.
class DataLoader {
public:
DataLoader(
std::shared_ptr<IDataView> dataview, mgb::CompNode compnode,
unsigned long batchsize = 1U, bool shuffle = false, bool drop_last = true);
/*!
     * Get the next batch of data and labels from the dataset.
*/
DataPair next();
/*!
     * Get the number of batches provided by the dataloader.
*/
size_t size();
private:
std::shared_ptr<IDataView> m_dataview;
mgb::CompNode m_comp_node;
unsigned long m_batchsize;
bool m_shuffle;
bool m_drop_last;
size_t m_idx;
mgb::TensorShape m_data_shape;
mgb::TensorShape m_label_shape;
mgb::DType m_data_type;
mgb::DType m_label_type;
    // Only used by the temporary shuffle implementation.
std::vector<int> m_index_collection;
};
} // namespace mgb
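// A minimal usage sketch; VectorDataView is a hypothetical IDataView
// implementation holding pre-built (data, label) host tensor pairs:
//
//     class VectorDataView final : public IDataView {
//     public:
//         explicit VectorDataView(std::vector<DataPair> samples)
//                 : m_samples{std::move(samples)} {}
//         DataPair get_item(int idx) override { return m_samples.at(idx); }
//         size_t size() override { return m_samples.size(); }
//
//     private:
//         std::vector<DataPair> m_samples;
//     };
//
//     auto dataview = std::make_shared<VectorDataView>(std::move(samples));
//     DataLoader loader(dataview, CompNode::load("cpu0"), 32, /* shuffle */ true);
//     for (size_t i = 0; i < loader.size(); i++) {
//         DataPair batch = loader.next();  // batched {data, label} host tensors
//     }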
/**
* \file src/opr/include/training/loss.h
*
* This file is part of MegBrain, a deep learning framework developed by Megvii.
*
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*/
#pragma once
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/tensor.h"
namespace mgb {
namespace loss {
//! The loss interface, which every loss class should inherit.
class ILoss {
public:
/*!
     * The reduce mode used to collapse the loss output to a scalar.
*/
enum ReduceMode { SUM = 0, MEAN = 1 };
/*!
     * Compute the loss; the output is expected to be a scalar SymbolVar.
*/
virtual mgb::SymbolVar operator()(
            mgb::SymbolVar symbol_pred, mgb::SymbolVar symbol_label) = 0;
virtual ~ILoss() = default;
};
/*!
 * The cross entropy loss. The definition can be found at:
 * https://en.wikipedia.org/wiki/Cross_entropy
 *
 * It corresponds to <CrossEntropy> in the Python API of MegEngine.
*/
class CrossEntropyLoss : public ILoss {
public:
CrossEntropyLoss(
bool with_logits = true, float label_smooth = .0f,
ReduceMode reduce_mode = ReduceMode::MEAN, int axis = 1);
    mgb::SymbolVar operator()(mgb::SymbolVar symbol_pred, mgb::SymbolVar symbol_label);
protected:
bool m_with_logits;
float m_label_smooth;
ReduceMode m_reduce_mode;
int m_axis;
};
/*!
 * The MSE (mean squared error) loss. The definition can be found at:
 * https://en.wikipedia.org/wiki/Mean_squared_error
 *
 * It corresponds to <MSE> in the Python API of MegEngine.
*/
class MSELoss : public ILoss {
public:
MSELoss(ReduceMode reduce_mode = ReduceMode::MEAN);
    mgb::SymbolVar operator()(mgb::SymbolVar symbol_pred, mgb::SymbolVar symbol_label);
protected:
ReduceMode m_reduce_mode;
};
} // namespace loss
} // namespace mgb
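// A minimal usage sketch, assuming `symbol_pred`, `symbol_label` (class
// indices) and `symbol_target` (same shape as the prediction) are SymbolVars
// already defined in a computing graph:
//
//     loss::CrossEntropyLoss cross_entropy;  // with_logits = true by default
//     SymbolVar symbol_ce = cross_entropy(symbol_pred, symbol_label);
//     loss::MSELoss mse(loss::ILoss::ReduceMode::SUM);
//     SymbolVar symbol_mse = mse(symbol_pred, symbol_target);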
/**
* \file src/opr/include/training/optimizer.h
*
* This file is part of MegBrain, a deep learning framework developed by Megvii.
*
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*/
#pragma once
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/tensor.h"
namespace mgb {
namespace optimizer {
//! The optimizer interface, which every optimizer should inherit.
class IOptimizer {
public:
/*!
     * Add operations to the graph that update the weights when the inputs are
     * SymbolVarArrays.
*/
virtual mgb::SymbolVarArray make_multiple(
mgb::SymbolVarArray symbol_weights, mgb::SymbolVarArray symbol_grads,
std::shared_ptr<mgb::cg::ComputingGraph> graph) = 0;
/*!
     * Add operations to the graph that update a single weight with a certain
     * strategy.
     * The output is expected to be the SymbolVar of the updated weight.
*/
virtual mgb::SymbolVar make(
mgb::SymbolVar symbol_weight, mgb::SymbolVar symbol_grad,
std::shared_ptr<mgb::cg::ComputingGraph> graph) = 0;
virtual ~IOptimizer() = default;
};
/*!
 * An abstract class that simplifies the implementation of optimizers.
 * It provides a default implementation of <make_multiple> based on the <make>
 * method defined by the derived class.
*/
class Optimizer : public IOptimizer {
public:
mgb::SymbolVarArray make_multiple(
mgb::SymbolVarArray symbol_weights, mgb::SymbolVarArray symbol_grads,
std::shared_ptr<mgb::cg::ComputingGraph> graph);
virtual mgb::SymbolVar make(
mgb::SymbolVar symbol_weight, mgb::SymbolVar symbol_grad,
std::shared_ptr<mgb::cg::ComputingGraph> graph) = 0;
virtual ~Optimizer() = default;
};
/*!
 * The SGD (stochastic gradient descent) optimizer.
 * The definition can be found at:
 * https://en.wikipedia.org/wiki/Stochastic_gradient_descent
 * It corresponds to <SGD> in the Python API of MegEngine.
*/
class SGD : public Optimizer {
public:
SGD() = default;
SGD(float lr, float weight_decay = .0f, float momentum = .0f);
SGD(const SGD& that) {
m_lr = that.m_lr;
m_momentum = that.m_momentum;
m_weight_decay = that.m_weight_decay;
}
mgb::SymbolVar make(
mgb::SymbolVar symbol_weight, mgb::SymbolVar symbol_grad,
std::shared_ptr<mgb::cg::ComputingGraph> graph);
const SGD& operator=(const SGD& that) {
m_lr = that.m_lr;
m_momentum = that.m_momentum;
m_weight_decay = that.m_weight_decay;
return *this;
}
protected:
float m_lr;
float m_weight_decay;
float m_momentum;
std::vector<std::shared_ptr<mgb::HostTensorND>> m_pre_grads;
};
/*!
 * The Adam optimizer. The definition can be found at:
 * https://en.wikipedia.org/wiki/Stochastic_gradient_descent#:~:text=full%2Dbatches.%5B26%5D-,Adam,-%5Bedit%5D
 * It corresponds to <Adam> in the Python API of MegEngine.
*/
class Adam : public Optimizer {
public:
Adam() = default;
Adam(float lr, float weight_decay = .0f,
std::pair<float, float> betas = {0.9f, 0.999f}, float eps = 1e-8f,
bool amsgrad = false);
Adam(const Adam& that) {
m_lr = that.m_lr;
m_betas = that.m_betas;
m_eps = that.m_eps;
m_weight_decay = that.m_weight_decay;
m_amsgrad = that.m_amsgrad;
}
mgb::SymbolVar make(
mgb::SymbolVar symbol_weight, mgb::SymbolVar symbol_grad,
std::shared_ptr<mgb::cg::ComputingGraph> graph);
const Adam& operator=(const Adam& that) {
m_lr = that.m_lr;
m_betas = that.m_betas;
m_eps = that.m_eps;
m_weight_decay = that.m_weight_decay;
m_amsgrad = that.m_amsgrad;
return *this;
}
protected:
float m_lr;
float m_weight_decay;
std::pair<float, float> m_betas;
float m_eps;
bool m_amsgrad;
std::vector<std::shared_ptr<mgb::DeviceTensorND>> m_exp_avg;
std::vector<std::shared_ptr<mgb::DeviceTensorND>> m_exp_avg_sq;
std::vector<std::shared_ptr<mgb::DeviceTensorND>> m_max_exp_avg_sq;
std::shared_ptr<mgb::HostTensorND> m_correction1;
std::shared_ptr<mgb::HostTensorND> m_correction2;
};
} // namespace optimizer
} // namespace mgb
/**
* \file src/opr/include/tensor_gen.h
*
* This file is part of MegBrain, a deep learning framework developed by Megvii.
*
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*/
#pragma once
#include "megbrain/tensor.h"
namespace mgb {
/*!
 * A utility class providing static methods to generate host tensors.
*/
class TensorGen {
public:
/*!
* \brief Generate a tensor with all the elements equal to the given value
*/
template <typename ctype, typename = typename mgb::ctype_enable_if<ctype>::type>
static std::shared_ptr<mgb::HostTensorND> constant(
mgb::TensorShape shape, ctype value,
mgb::CompNode comp_node = mgb::CompNode::load("xpu0")) {
std::shared_ptr<mgb::HostTensorND> r = std::make_shared<mgb::HostTensorND>(
comp_node, shape, typename mgb::DTypeTrait<ctype>::dtype());
auto ptr = r->ptr<ctype>();
for (size_t i = 0, it = r->layout().total_nr_elems(); i < it; i++) {
ptr[i] = value;
}
return r;
}
/*!
* \brief Generate a tensor with all the elements equal to 0
*/
template <typename T>
static std::shared_ptr<mgb::HostTensorND> zeros(
mgb::TensorShape shape,
mgb::CompNode comp_node = mgb::CompNode::load("xpu0")) {
static_assert(
std::is_base_of<mgb::DType, T>(),
"Please use the dtype in namespace mgb or use "
"Tensor::constant.");
using ctype = typename mgb::DTypeTrait<T>::ctype;
return constant(shape, (ctype)0, comp_node);
}
/*!
     * \brief Generate a tensor with all the elements equal to 0. Here the
     * dtype is passed as a runtime argument instead of a template parameter.
*/
static std::shared_ptr<mgb::HostTensorND> zeros(
mgb::TensorShape shape, mgb::DType dtype = mgb::dtype::Float32(),
mgb::CompNode comp_node = mgb::CompNode::load("xpu0")) {
std::shared_ptr<mgb::HostTensorND> r =
std::make_shared<mgb::HostTensorND>(comp_node, shape, dtype);
auto ptr = r->raw_ptr();
        memset(ptr, 0, r->layout().access_bytes());
return r;
}
/*!
* \brief Generate a tensor with all the elements equal to 1
*/
template <typename T>
static std::shared_ptr<mgb::HostTensorND> ones(
mgb::TensorShape shape,
mgb::CompNode comp_node = mgb::CompNode::load("xpu0")) {
static_assert(
std::is_base_of<mgb::DType, T>(),
"Please use the dtype in namespace mgb or use "
"Tensor::constant.");
using ctype = typename mgb::DTypeTrait<T>::ctype;
return constant(shape, (ctype)1, comp_node);
}
};
} // namespace mgb
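// A brief usage sketch of the generators above:
//
//     auto w = TensorGen::constant({2, 3}, 1.5f);          // dtype deduced from the value
//     auto z = TensorGen::zeros<dtype::Float32>({4});      // template form
//     auto z2 = TensorGen::zeros({4}, dtype::Float32());   // runtime-dtype form
//     auto o = TensorGen::ones<dtype::Int32>({8}, CompNode::load("cpu0"));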
/**
* \file src/opr/test/training/loss.cpp
*
* This file is part of MegBrain, a deep learning framework developed by Megvii.
*
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*/
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/tensor.h"
#include "megbrain/test/helper.h"
#include "megbrain/opr/training/loss.h"
using namespace mgb;
using namespace loss;
namespace {
class Device2HostCallback {
public:
Device2HostCallback(std::shared_ptr<HostTensorND> host) : m_host{host} {}
void operator()(const DeviceTensorND& device) { m_host->copy_from(device).sync(); }
private:
std::shared_ptr<HostTensorND> m_host;
};
class CrossEntropyTest : public ::testing::Test {
private:
/* data */
std::shared_ptr<HostTensorND> pred, label, truth, loss;
TensorShape pred_shape = {2, 10};
TensorShape label_shape = {2};
TensorShape truth_shape = {1};
std::vector<float> pred_values = {
-0.22847f, -0.65020f, -0.42470f, 1.32903f, -0.58377f, -0.15881f, -0.23134f,
-0.36147f, -1.05848f, -0.23285f, 0.32360f, -0.36430f, -0.03172f, 1.18970f,
-0.23465f, -0.16139f, -0.22942f, -0.22538f, -0.68029f, -0.41004f};
std::vector<int> label_values = {5, 3};
std::vector<float> truth_values = {1.8120441};
CompNode node = CompNode::load("cpu0");
std::shared_ptr<cg::ComputingGraph> graph;
CrossEntropyLoss cross_entropy_loss;
public:
std::unique_ptr<cg::AsyncExecutable> func;
void setup();
void build_model(float label_smooth = .0f);
void verify();
template <typename T>
void assign_value(std::shared_ptr<HostTensorND> tensor, std::vector<T> value);
};
} // namespace
void CrossEntropyTest::setup() {
pred = std::make_shared<HostTensorND>(node, pred_shape, dtype::Float32());
label = std::make_shared<HostTensorND>(node, label_shape, dtype::Int32());
truth = std::make_shared<HostTensorND>(node, truth_shape, dtype::Float32());
loss = std::make_shared<HostTensorND>(node, truth_shape, dtype::Float32());
assign_value<float>(pred, pred_values);
assign_value<int>(label, label_values);
assign_value<float>(truth, truth_values);
}
template <typename T>
void CrossEntropyTest::assign_value(
std::shared_ptr<HostTensorND> tensor, std::vector<T> values) {
ASSERT_EQ(values.size(), tensor->shape().total_nr_elems());
auto ptr = tensor->ptr<T>();
for (size_t i = 0, it = tensor->shape().total_nr_elems(); i < it; i += 1) {
ptr[i] = values.at(i);
}
}
void CrossEntropyTest::build_model(float label_smooth) {
    cross_entropy_loss = CrossEntropyLoss(true, label_smooth);
    graph = cg::ComputingGraph::make();
SymbolVar symbol_pred = opr::SharedDeviceTensor::make(*graph, *pred);
SymbolVar symbol_label = opr::SharedDeviceTensor::make(*graph, *label);
SymbolVar symbol_loss = cross_entropy_loss(symbol_pred, symbol_label);
cg::ComputingGraph::OutputSpec spec;
spec.push_back({symbol_loss, Device2HostCallback(loss)});
func = graph->compile(spec);
}
void CrossEntropyTest::verify() {
func->execute().wait();
ASSERT_NEAR(loss->ptr<float>()[0], truth->ptr<float>()[0], 0.001f);
}
TEST_F(CrossEntropyTest, CrossEntropy) {
setup();
build_model();
verify();
}
/**
* \file src/opr/test/training/optimizer.cpp
*
* This file is part of MegBrain, a deep learning framework developed by Megvii.
*
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*/
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/tensor.h"
#include "megbrain/test/helper.h"
#include "megbrain/opr/training/optimizer.h"
#include "megbrain/opr/training/utils.h"
using namespace mgb;
using namespace optimizer;
namespace {
class Device2HostCallback {
public:
Device2HostCallback(std::shared_ptr<HostTensorND> host) : m_host{host} {}
void operator()(const DeviceTensorND& device) { m_host->copy_from(device).sync(); }
private:
std::shared_ptr<HostTensorND> m_host;
};
template <typename T>
void assign_value(std::shared_ptr<HostTensorND>& tensor, std::vector<T>& values) {
ASSERT_EQ(values.size(), tensor->layout().total_nr_elems());
auto ptr = tensor->ptr<T>();
for (size_t i = 0, it = tensor->layout().total_nr_elems(); i < it; i += 1) {
ptr[i] = values.at(i);
}
}
class OptimizerTest : public ::testing::Test {
public:
void verify(
std::shared_ptr<IOptimizer> optimizer, std::shared_ptr<HostTensorND> weight,
std::shared_ptr<HostTensorND> grad, std::shared_ptr<HostTensorND> truth,
int execute_times);
protected:
std::shared_ptr<IOptimizer> optimizer;
std::shared_ptr<cg::ComputingGraph> graph;
};
void OptimizerTest::verify(
std::shared_ptr<IOptimizer> optimizer, std::shared_ptr<HostTensorND> weight,
std::shared_ptr<HostTensorND> grad, std::shared_ptr<HostTensorND> truth,
int execute_times) {
graph = cg::ComputingGraph::make();
SymbolVar symbol_weight = opr::SharedDeviceTensor::make(*graph, *weight);
SymbolVar symbol_grad = opr::SharedDeviceTensor::make(*graph, *grad);
cg::ComputingGraph::OutputSpec spec;
spec.push_back(
{optimizer->make(symbol_weight, symbol_grad, graph),
Device2HostCallback(weight)});
auto func = graph->compile(spec);
for (int i = 0; i < execute_times; i++) {
func->execute();
}
auto weight_ptr = weight->ptr<float>();
auto truth_ptr = truth->ptr<float>();
for (size_t i = 0, it = weight->shape().total_nr_elems(); i < it; i += 1) {
ASSERT_NEAR(weight_ptr[i], truth_ptr[i], 0.001f);
}
}
} // namespace
TEST_F(OptimizerTest, SGD) {
auto weight = TensorGen::constant({1}, 0.30542f);
auto grad = TensorGen::constant({1}, -1.81453f);
auto truth = TensorGen::constant({1}, 1.04673f);
int execute_times = 10;
std::shared_ptr<SGD> sgd = std::make_shared<SGD>(0.01f, 5e-2f, 0.9f);
verify(sgd, weight, grad, truth, execute_times);
}
TEST_F(OptimizerTest, AdamTest) {
auto weight = TensorGen::constant({1}, 1.62957f);
auto grad = TensorGen::constant({1}, 1.02605f);
auto truth = TensorGen::constant({1}, 1.52969f);
int execute_times = 10;
std::shared_ptr<Adam> adam = std::make_shared<Adam>(0.01f, 0.9f);
verify(adam, weight, grad, truth, execute_times);
}