提交 b601603b 编写于 作者: Y yoni

train on device

上级 12109600
......@@ -69,6 +69,7 @@ class MS_API Model {
/// \brief Free MetaGraph in MindSpore Lite Model.
void FreeMetaGraph();
ModelImpl *model_impl() {return model_impl_;}
protected:
ModelImpl *model_impl_ = nullptr;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_SESSION_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_SESSION_H_
#include <vector>
#include <string>
#include <unordered_map>
// #include "include/lite_session.h"
#include "src/lite_session.h"
namespace mindspore {
namespace lite {
class Model;
}
namespace lite::tensor {
class Tensor;
}
namespace session {
class TrainSession : public lite::LiteSession {
public:
TrainSession();
~TrainSession() = default;
int RunGraph(const session::KernelCallBack &before = nullptr,
const session::KernelCallBack &after = nullptr) override;
int CompileGraph(lite::Model *model) override;
virtual void ReplaceOps();
virtual void* ExportToBuf(void* buf, size_t* len) const;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputs() const;
std::vector<tensor::MSTensor *> GetOutputsByName(const std::string &node_name) const;
virtual void train();
bool is_train() { return train_mode_ == true; }
virtual void eval();
bool is_eval() { return train_mode_ == false; }
protected:
bool train_mode_ = false;
lite::Model* model_ = nullptr;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> ext_output_map_;
// private:
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_SESSION_H_
......@@ -13,9 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/activation_grad.h"
int ReluGrad(float *src0, float *src1, int length, float *dst) {
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/fp32/arithmetic.h"
#include "nnacl/fp32_grad/activation_grad.h"
#include "nnacl/errorcode.h"
inline int ReluGrad(float *src0, float *src1, int length, float *dst) {
for (int i = 0; i < length; ++i) {
dst[i] = src1[i] > 0 ? 1.0f : 0.0f;
}
......
......@@ -13,11 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string.h>
#include <math.h>
#include <string.h>
#include "nnacl/fp32_grad/batch_norm.h"
static void sumSpatialBatch(const float *in, int size, int ch, float *out) {
void sumSpatialBatch(const float *in, int size, int ch, float *out) {
memset(out, 0, ch * sizeof(float));
for (int i = 0; i < size; i++) {
const float *ptr = in + i * ch;
......@@ -32,49 +32,53 @@ void scaleBias(const float *scales, int batch, int n, int size, float *output) {
for (int c = 0; c < n; c++) output[i * n + c] *= scales[c];
}
void normalize(const float *x, const float *mean, const float *variance, float eps, int batch, int filters, int spatial,
void normalize(const float *x, const float *mean, const float *invar, int batch, int filters, int spatial,
float *out) {
int b, f, i;
for (b = 0; b < batch; ++b) {
for (i = 0; i < spatial; ++i) {
for (f = 0; f < filters; ++f) {
int index = b * filters * spatial + i * filters + f;
out[index] = (x[index] - mean[f]) / (sqrt(variance[f]) + eps);
out[index] = (x[index] - mean[f]) * invar[f];
}
}
}
}
void backwardScale(const float *x_norm, const float *delta, int batch, int n, int size, float *scale_updates) {
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch,
int n, int size, float *scale_updates) {
int i, b, f;
memset(scale_updates, 0, n * sizeof(float));
for (b = 0; b < batch; ++b) {
for (i = 0; i < size; ++i) {
for (f = 0; f < n; ++f) {
int index = (b * size + i) * n + f;
scale_updates[f] += delta[index] * x_norm[index];
float x_norm = (x[index] - mean[f]) * invar[f];
scale_updates[f] += delta[index] * x_norm;
}
}
}
}
void meanVar(const float *in, int batch, int spatial, int ch, float *mean, float *var) {
void meanVar(const float *in, int batch, int spatial, int ch, float eps, float *mean, float *invar) {
float N = batch * spatial;
sumSpatialBatch(in, N, ch, mean);
for (int f = 0; f < ch; ++f) mean[f] /= N;
memset(var, 0, ch * sizeof(float));
for (int i = 0; i < N; i++) {
for (int f = 0; f < ch; f++) {
float x = in[i * ch + f];
var[f] += (x - mean[f]) * (x - mean[f]);
for (int f = 0; f < ch; ++f) {
mean[f] /= N;
}
for (int f=0; f< ch; f++) {
float tvar = 0;
for (int i =0; i< N; i++) {
float x = in[i*ch +f];
tvar += (x-mean[f]) *(x-mean[f]);
}
invar[f] = 1.0f/(sqrt(tvar/N+eps));
}
for (int f = 0; f < ch; f++) var[f] /= N;
}
void meanDelta(float *yt, int size, int ch, float eps, float *variance, float *mean_delta) {
void meanDelta(float *yt, int size, int ch, float *invar, float *mean_delta) {
sumSpatialBatch(yt, size, ch, mean_delta);
for (int i = 0; i < ch; i++) mean_delta[i] *= -1.f / sqrt((variance[i] + eps));
for (int i = 0; i < ch; i++) mean_delta[i] *= -invar[i];
}
void meanAdd(const float *x, const float *mean, const float *variance_delta, int batch, int filters, int spatial,
......@@ -93,8 +97,8 @@ void meanAdd(const float *x, const float *mean, const float *variance_delta, int
}
}
void varianceDelta(const float *x, const float *delta, const float *mean, const float *variance, int batch, int filters,
int spatial, float eps, float *variance_delta) {
void varianceDelta(const float *x, const float *delta, const float *mean, const float *invar, int batch, int filters,
int spatial, float *variance_delta) {
int i, k;
memset(variance_delta, 0, filters * sizeof(float));
for (k = 0; k < batch * spatial; k++) {
......@@ -103,16 +107,16 @@ void varianceDelta(const float *x, const float *delta, const float *mean, const
variance_delta[i] += delta[index] * (x[index] - mean[i]);
}
}
for (i = 0; i < filters; i++) variance_delta[i] *= -.5 * pow(variance[i] + eps, (-3.f / 2.f));
for (i = 0; i < filters; i++) variance_delta[i] *= -.5 * 1.0f/(invar[i]*invar[i]*invar[i]);
}
void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta,
const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta) {
void NormalizeDelta(const float *x, const float *mean, const float *invar, const float *mean_delta,
const float *variance_delta, int batch, int filters, int spatial, float *delta) {
int f, k;
for (k = 0; k < batch * spatial; k++) {
for (f = 0; f < filters; f++) {
int index = k * filters + f;
delta[index] = delta[index] * 1. / (sqrt(variance[f] + eps)) +
delta[index] = delta[index] * invar[f] +
variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) +
mean_delta[f] / (spatial * batch);
}
......
......@@ -17,28 +17,33 @@
#ifndef MINDSPORE_LITE_NNACL_FP32_BATCH_NORM_H_
#define MINDSPORE_LITE_NNACL_FP32_BATCH_NORM_H_
typedef struct bnParameter {
int batch;
int channels;
int spatial;
float eps;
} bnParameter;
#include "nnacl/op_base.h"
typedef struct BNGradParameter {
OpParameter op_parameter_;
float epsilon_;
float momentum_;
} BNGradParameter;
#ifdef __cplusplus
extern "C" {
#endif
void sumSpatialBatch(const float *in, int size, int ch, float *out);
void scaleBias(const float *scales, int batch, int n, int size, float *output);
void normalize(const float *x, const float *mean, const float *variance, float eps, int batch, int filters, int spatial,
void normalize(const float *x, const float *mean, const float *invar, int batch, int filters, int spatial,
float *out);
void backwardScale(const float *x_norm, const float *delta, int batch, int n, int size, float *scale_updates);
void meanVar(const float *in, int batch, int size, int ch, float *mean, float *var);
void meanDelta(float *yt, int size, int ch, float eps, float *variance, float *mean_delta);
void varianceDelta(const float *x, const float *delta, const float *mean, const float *variance, int batch, int ch,
int spatial, float eps, float *variance_delta);
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch,
int n, int size, float *scale_updates);
void meanVar(const float *in, int batch, int size, int ch, float eps, float *mean, float *invar);
void meanDelta(float *yt, int size, int ch, float *invar, float *mean_delta);
void varianceDelta(const float *x, const float *delta, const float *mean, const float *invar, int batch, int ch,
int spatial, float *variance_delta);
void meanAdd(const float *x, const float *mean, const float *variance_delta, int batch, int filters, int spatial,
float *mean_add, float *mean_delta);
void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta,
const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta);
void NormalizeDelta(const float *x, const float *mean, const float *invar, const float *mean_delta,
const float *variance_delta, int batch, int filters, int spatial, float *delta);
#ifdef __cplusplus
}
#endif
......
......@@ -125,9 +125,9 @@ void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param
}
void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param) {
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_;
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_l_;
// const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_;
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_;
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_u_;
// const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_;
const int stride_h = conv_param->stride_h_;
......
......@@ -13,7 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include <stdint.h>
#include <float.h>
#include "nnacl/fp32_grad/pooling_grad.h"
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) {
......@@ -31,33 +32,37 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
int output_batch = pooling_param->output_batch_;
const float *inPtr = NULL;
for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0;
// for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0;
for (int i = 0; i < in_h * in_w * channel * output_batch; i++) output_ptr[i] = 0.0;
float kk = (float)(win_h * win_w);
for (uint16_t ib = 0; ib < output_batch; ib++) {
float *out;
out = &output_ptr[(ib * output_h * output_w)];
inPtr = (float *)(&input_ptr[(ib * in_h * in_w)]);
// out = &output_ptr[(ib * output_h * output_w)];
out = &output_ptr[(ib * in_h * in_w * channel)];
// inPtr = (float *)(&input_ptr[(ib * in_h * in_w)]);
inPtr = (float *)(&input_ptr[(ib * output_h * output_w * channel)]);
if (1) { // in->layout() == Tensor::nhwc)
// iterate over yt
for (uint16_t yh = 0; yh < in_h; yh++) {
for (uint16_t yw = 0; yw < in_w; yw++) {
for (uint16_t yh = 0; yh < output_h; yh++) {
for (uint16_t yw = 0; yw < output_w; yw++) {
for (uint16_t ic = 0; ic < channel; ic++) {
int idx = (yw + yh * in_w) * channel + ic; // (ic*in_h*in_w) + (in_w*yh) + yw;
int idx = (yw + yh * output_w) * channel + ic; // (ic*in_h*in_w) + (in_w*yh) + yw;
float delta = inPtr[idx] / kk;
for (int32_t kh = 0; kh < win_h; kh++) {
int xh = yh * stride_h + kh - pad_h;
if ((xh < 0) || (xh >= output_h)) {
if ((xh < 0) || (xh >= in_h)) {
continue;
}
for (int32_t kw = 0; kw < win_w; kw++) {
int xw = yw * stride_w + kw - pad_w;
if ((xw < 0) || (xw >= output_w)) {
if ((xw < 0) || (xw >= in_w)) {
continue;
}
// out[(ic*output_h*output_w) + (xh*output_w) + xw] += delta;
out[(xw + output_w * xh) * channel + ic] += delta;
// out[(xw + output_w * xh) * channel + ic] += delta;
out[(xw + in_w * xh) * channel + ic] += delta;
}
}
}
......@@ -66,21 +71,22 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
} else { // nchw
for (uint16_t ic = 0; ic < channel; ic++) {
// iterate over yt
for (uint16_t yh = 0; yh < in_h; yh++) {
for (uint16_t yw = 0; yw < in_w; yw++) {
int idx = (ic * in_h * in_w) + (in_w * yh) + yw;
for (uint16_t yh = 0; yh < output_h; yh++) {
for (uint16_t yw = 0; yw < output_w; yw++) {
int idx = (ic * output_h * output_w) + (output_w * yh) + yw;
float delta = inPtr[idx] / kk;
for (int32_t kh = 0; kh < win_h; kh++) {
int xh = yh * stride_h + kh - pad_h;
if ((xh < 0) || (xh >= output_h)) {
if ((xh < 0) || (xh >= in_h)) {
continue;
}
for (int32_t kw = 0; kw < win_w; kw++) {
int xw = yw * stride_w + kw - pad_w;
if ((xw < 0) || (xw >= output_w)) {
if ((xw < 0) || (xw >= in_w)) {
continue;
}
out[(ic * output_h * output_w) + (xh * output_w) + xw] += delta;
// out[(ic * output_h * output_w) + (xh * output_w) + xw] += delta;
out[(ic * in_h * in_w) + (xh * in_w) + xw] += delta;
}
}
}
......@@ -90,7 +96,14 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
}
}
void MaxPoolingGrad(const float *dy, const int *indices, float *output_ptr, PoolingParameter *pooling_param) {
void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr,
PoolingParameter *pooling_param) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
......@@ -98,38 +111,73 @@ void MaxPoolingGrad(const float *dy, const int *indices, float *output_ptr, Pool
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_img_size =
output_h * output_w; // Emir -- in original code this varible is calculated according to input size ??
int ind_img_size = in_h * in_w;
// const int w_pad = (output_w + pad_w + pad_w);
const float *inPtr;
const float *dyPtr;
for (int i = 0; i < in_h * in_w * channel * output_batch; i++) output_ptr[i] = 0.0;
for (uint16_t ib = 0; ib < output_batch; ib++) {
float *out;
out = &output_ptr[(ib * in_h * in_w * channel)];
inPtr = (const float *)(&input_ptr[(ib * in_h * in_w * channel)]);
dyPtr = (const float *)(&dy_ptr[(ib * output_h * output_w * channel)]);
for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0;
if (1) { // nhwc
for (uint16_t yh = 0; yh < output_h; yh++) {
for (uint16_t yw = 0; yw < output_w; yw++) {
for (uint16_t ic = 0; ic < channel; ic++) {
int idx = (yw + yh * output_w) * channel + ic;
const float *yt = (const float *)(dy);
const int *pos = (const int *)(indices);
float *out = NULL;
float delta = dyPtr[idx];
float max_val = -FLT_MAX;
int max_idx = 0;
for (int32_t kh = 0; kh < win_h; kh++) {
int xh = yh * stride_h + kh - pad_h;
if ((xh < 0) || (xh >= in_h)) {
continue;
}
for (int32_t kw = 0; kw < win_w; kw++) {
int xw = yw * stride_w + kw - pad_w;
if ((xw < 0) || (xw >= in_w)) {
continue;
}
if (1) { // grads->layout() == Tensor::nhwc)
for (int ib = 0; ib < output_batch; ib++) {
out = &(output_ptr[ib * output_w * output_w * channel]);
for (int ix = 0; ix < ind_img_size; ix++) {
for (int cix = 0; cix < channel; cix++) {
int idx = (*pos) * channel + cix;
out[idx] += *yt;
pos++;
yt++;
if (inPtr[(xw + in_w * xh) * channel + ic] > max_val) {
max_val = inPtr[(xw + in_w * xh) * channel + ic];
max_idx = (xw + in_w * xh) * channel + ic;
}
}
}
out[max_idx] += delta;
}
}
}
}
} else {
for (int ib = 0; ib < output_batch; ib++) {
out = &output_ptr[(ib * out_img_size)];
for (int cix = 0; cix < channel; cix++) {
for (int ix = 0; ix < ind_img_size; ix++) {
int idx = cix * output_h * output_w + *pos; // cord_y*output_w + cord_x;
out[idx] += *yt;
pos++;
yt++;
} else { // nchw
for (uint16_t yh = 0; yh < output_h; yh++) {
for (uint16_t yw = 0; yw < output_w; yw++) {
for (uint16_t ic = 0; ic < channel; ic++) {
int idx = (ic * output_h * output_w) + (output_w * yh) + yw;
float delta = dyPtr[idx];
float max_val = -FLT_MAX;
int max_idx = 0;
for (int32_t kh = 0; kh < win_h; kh++) {
int xh = yh * stride_h + kh - pad_h;
if ((xh < 0) || (xh >= in_h)) {
continue;
}
for (int32_t kw = 0; kw < win_w; kw++) {
int xw = yw * stride_w + kw - pad_w;
if ((xw < 0) || (xw >= in_w)) {
continue;
}
if (inPtr[(ic * in_h * in_w) + (xh * in_w) + xw] > max_val) {
max_val = inPtr[(ic * in_h * in_w) + (xh * in_w) + xw];
max_idx = (ic * in_h * in_w) + (xh * in_w) + xw;
}
}
}
out[max_idx] += delta;
}
}
}
}
......
......@@ -23,7 +23,9 @@
extern "C" {
#endif
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param);
void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param);
// void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param);
void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr,
PoolingParameter *pooling_param);
#ifdef __cplusplus
}
#endif
......
......@@ -13,10 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string.h>
#include "nnacl/fp32_grad/reduce_grad.h"
static inline bool NextIndex(const int num_dims, const int *dims, int *current) {
static inline int NextIndex(const int num_dims, const int *dims, int *current) {
int carry = 1;
for (int idx = num_dims - 1; idx >= 0; --idx) {
int current_val = current[idx] + carry;
......@@ -45,10 +45,10 @@ static inline size_t GetOutputOffset(const int num_dims, const int *dims, const
size_t offset = 0;
for (int idx = 0; idx < num_dims; ++idx) {
// if we need to skip this axis
bool is_axis = false;
int is_axis = 0;
for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
if (idx == axes[axis_idx]) {
is_axis = true;
is_axis = 1;
break;
}
}
......@@ -101,10 +101,10 @@ float ReduceMeanAll(const float *src, int size) {
void ReduceSumByAxes(const float *input, const int *input_dims, float *output, const int *output_dims, int num_dims) {
int num_outputs = 1;
int same_shape = true;
int same_shape = 1;
for (int idx = 0; idx < num_dims; ++idx) {
num_outputs *= output_dims[idx];
if (output_dims[idx] != input_dims[idx]) same_shape = false;
if (output_dims[idx] != input_dims[idx]) same_shape = 0;
}
if (same_shape) {
memcpy(output, input, num_outputs * sizeof(float));
......
......@@ -17,8 +17,7 @@
#ifndef MINDSPORE_LITE_NNACL_FP32_REDUCE_GRAD_H_
#define MINDSPORE_LITE_NNACL_FP32_REDUCE_GRAD_H_
#include <cstddef.h>
#include <algorithm.h>
#include <stddef.h>
#ifdef __cplusplus
extern "C" {
......
......@@ -20,7 +20,7 @@
#include "nnacl/op_base.h"
typedef struct SoftmaxCrossEntropyParameter {
OpParameter op_parameter;
OpParameter op_parameter_;
int32_t batch_size_;
unsigned int number_of_classes_;
int n_dim_;
......
......@@ -178,8 +178,8 @@ union PrimitiveType {
Conv2DGradFilter,
Conv2DGradInput,
PoolingGrad,
BNGradInput,
OptMomentum,
BNGrad,
ApplyMomentum,
BiasGrad,
SoftmaxCrossEntropy,
AddGrad,
......@@ -190,6 +190,7 @@ union PrimitiveType {
ActivationGrad,
PriorBox,
SpaceToBatchND,
Depend,
Return,
MakeTuple,
ToFormat,
......
......@@ -149,7 +149,8 @@ table Activation {
alpha: float = 0.2;
}
table ActivationGrad {
type: ActivationGradType = 0;
type: ActivationType = 0;
alpha: float = 0.2;
}
......@@ -230,6 +231,9 @@ table SoftmaxCrossEntropy {
axis: [int];
}
table make_tuple {
}
table PoolingGrad {
format: Format = 0;
......@@ -390,10 +394,11 @@ table DeConv2D {
hasBias: bool = false;
activationType: ActivationType = 0;
}
table BNGradInput {
table BNGrad {
eps : float;
channels: int;
momentum: float;
}
table Scale {
axis: int;
}
......@@ -841,7 +846,10 @@ table SquaredDifference {
table TupleGetItem {
}
table OptMomentum {
table ApplyMomentum {
gradientScale: float;
useLocking: bool;
useNesterov: bool;
}
......@@ -884,6 +892,10 @@ table ToFormat {
dstT: int;
}
table Depend {
}
table Return {
}
......
......@@ -27,7 +27,7 @@ set(LITE_SRC
)
if (SUPPORT_GPU)
set(LITE_SRC
set(LITE_SRC
${LITE_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/subgraph_opencl_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc
......@@ -36,6 +36,24 @@ if (SUPPORT_GPU)
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_runtime.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_wrapper.cc
)
endif()
if (SUPPORT_TRAIN)
set(ANF_SRC
${ANF_SRC}
)
set(PASS_SRC)
set(LITE_SRC
${LITE_SRC}
${ANF_SRC}
# ${CMAKE_CURRENT_SOURCE_DIR}/train/ops/train_ops.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
)
endif ()
file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc)
......
......@@ -110,6 +110,7 @@ int CompareOutputData(float *output_data, float *correct_data, int data_size) {
}
}
error /= data_size;
if (error > 0.0001) {
printf("has accuracy error!\n");
printf("%f\n", error);
......@@ -118,12 +119,14 @@ int CompareOutputData(float *output_data, float *correct_data, int data_size) {
return 0;
}
void CompareOutput(float *output_data, std::string file_path) {
int CompareOutput(float *output_data, std::string file_path) {
size_t output_size;
auto ground_truth = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size));
size_t output_num = output_size / sizeof(float);
printf("output num : %zu\n", output_num);
CompareOutputData(output_data, ground_truth, output_num);
int res = CompareOutputData(output_data, ground_truth, output_num);
delete [] ground_truth;
return res;
}
} // namespace lite
} // namespace mindspore
......@@ -47,7 +47,7 @@ void WriteToTxt(const std::string& file_path, void *data, size_t element_size) {
int WriteToBin(const std::string& file_path, void *data, size_t size);
int CompareOutputData(float *output_data, float *correct_data, int data_size);
void CompareOutput(float *output_data, std::string file_path);
int CompareOutput(float *output_data, std::string file_path);
std::string GetAndroidPackageName();
std::string GetAndroidPackagePath();
......
......@@ -47,7 +47,9 @@ int CompareRelativeOutput(float *output_data, std::string file_path) {
auto ground_truth = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size));
size_t output_num = output_size / sizeof(float);
std::cout << "output num : " << output_num << "\n";
return CompareOutputRelativeData(output_data, ground_truth, output_num);
int res = CompareOutputRelativeData(output_data, ground_truth, output_num);
delete [] ground_truth;
return res;
}
} // namespace lite
} // namespace mindspore
......@@ -39,6 +39,10 @@ int Executor::Run(std::vector<tensor::Tensor *> &in_tensors, std::vector<tensor:
}
}
kernel::LiteKernelUtil::InitTensorRefCount(kernels);
for (auto out_tensor : out_tensors) { // increase RefCount of output tensors, such that Run will not free them
out_tensor->SetRefCount(out_tensor->RefCount() + 1);
}
for (auto *kernel : kernels) {
MS_ASSERT(nullptr != kernel);
......@@ -48,6 +52,8 @@ int Executor::Run(std::vector<tensor::Tensor *> &in_tensors, std::vector<tensor:
MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->name();
}
}
// JBDEBUG
// std::cout << "executing kernel " << kernel->name() << "\n";
auto ret = kernel->Run();
if (0 != ret) {
MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
......
......@@ -27,7 +27,6 @@
#include "src/ir/tensor.h"
#include "include/errorcode.h"
// using mindspore::kernel::AddressPtr;
namespace mindspore::kernel {
using mindspore::lite::RET_ERROR;
......
......@@ -112,11 +112,11 @@ int ModelImpl::BuildOps() {
Model *Model::Import(const char *model_buf, size_t size) {
auto model = new Model();
model->model_impl_ = ModelImpl::Import(model_buf, size);
if (model_buf == nullptr) {
MS_LOG(ERROR) << "model buf is null";
return nullptr;
}
model->model_impl_ = ModelImpl::Import(model_buf, size);
if (model->model_impl_ == nullptr) {
MS_LOG(ERROR) << "model impl is null";
return nullptr;
......
......@@ -20,11 +20,11 @@ namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ActivationGrad::GetType() const { return this->primitive_->value.AsActivationGrad()->type; }
float ActivationGrad::GetAlpha() const { return this->primitive_->value.AsActivationGrad()->alpha; }
void ActivationGrad::SetType(int type) {
this->primitive_->value.AsActivationGrad()->type = (schema::ActivationGradType)type;
this->primitive_->value.AsActivationGrad()->type = (schema::ActivationType)type;
}
void ActivationGrad::SetAlpha(float alpha) { this->primitive_->value.AsActivationGrad()->alpha = alpha; }
#else
int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
......@@ -40,7 +40,7 @@ int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flat
return RET_OK;
}
int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); }
float ActivationGrad::GetAlpha() const { return this->primitive_->value_as_ActivationGrad()->alpha(); }
#endif
} // namespace lite
} // namespace mindspore
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
#ifndef MINDSPORE_LITE_SRC_OPS_ACTIVATION_GRAD_H_
#define MINDSPORE_LITE_SRC_OPS_ACTIVATION_GRAD_H_
#include <vector>
#include <set>
......@@ -32,13 +32,15 @@ class ActivationGrad : public PrimitiveC {
ActivationGrad() = default;
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetType(int type);
void SetAlpha(float alpha);
#else
ActivationGrad() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetType() const;
float GetAlpha() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
#endif // MINDSPORE_LITE_SRC_OPS_ACTIVATION_GRAD_H_
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/apply_momentum.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
#else
int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_ApplyMomentum();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_ApplyMomentum return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateApplyMomentum(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ActivationGrad, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
int ApplyMomentum::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
if (5 != inputs.size()) {
MS_LOG(ERROR) << "ApplyMomentum should have at 5 input tensors";
return RET_ERROR;
}
// if (outputs.empty()) {
// MS_LOG(ERROR) << "ApplyMomentumCPUKernel error input output size!";
// return RET_ERROR;
// }
if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum() || inputs[0]->ElementsNum() != inputs[3]->ElementsNum() ||
inputs[2]->ElementsNum() != 1 || inputs[4]->ElementsNum() != 1) {
MS_LOG(ERROR) << "error input data size!";
return RET_ERROR;
}
if (!outputs.empty()) {
auto *out = outputs.front();
MS_ASSERT(out != nullptr);
out->set_data_type(inputs[0]->data_type());
out->SetFormat(inputs[0]->GetFormat());
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_APPLY_MOMENTUM_H_
#define MINDSPORE_LITE_SRC_OPS_APPLY_MOMENTUM_H_
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class ApplyMomentum : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC);
ApplyMomentum() = default;
explicit ApplyMomentum(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
ApplyMomentum() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_OPS_APPLY_MOMENTUM_H_
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/arithmetic_grad.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
namespace mindspore {
namespace lite {
int ArithmeticGrad::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
std::vector<lite::tensor::Tensor *> outputs_) {
if (inputs_.size() != 3) {
MS_LOG(ERROR) << "The number of input must be 3";
return RET_ERROR;
}
if (outputs_.size() != 2) {
MS_LOG(ERROR) << "The number of output must be 2";
return RET_ERROR;
}
auto dy = inputs_[0];
auto x1 = inputs_[1];
auto x2 = inputs_[2];
auto dx1 = outputs_[0];
auto dx2 = outputs_[1];
MS_ASSERT(dy != nullptr);
MS_ASSERT(x1 != nullptr);
MS_ASSERT(x2 != nullptr);
MS_ASSERT(dx1 != nullptr);
MS_ASSERT(dx2 != nullptr);
auto inShape0 = x1->shape();
auto inShape1 = x2->shape();
auto outShape = dy->shape();
if ((Type() == schema::PrimitiveType_AddGrad) || (Type() == schema::PrimitiveType_SubGrad)) {
ndim_ = outShape.size();
auto fillDimNum0 = outShape.size() - inShape0.size();
auto fillDimNum1 = outShape.size() - inShape1.size();
int j0 = 0;
int j1 = 0;
for (unsigned int i = 0; i < outShape.size(); i++) {
x1_shape_[i] = (i < fillDimNum0) ? 1 : inShape0[j0++];
x2_shape_[i] = (i < fillDimNum1) ? 1 : inShape1[j1++];
dy_shape_[i] = outShape[i];
}
} else {
// if (inShape0.size() < inShape1.size())
if (dx1->ElementsNum() < dx2->ElementsNum()) {
ndim_ = inShape1.size();
auto fillDimNum = inShape1.size() - inShape0.size(); // This will not work for batch!
int j = 0;
for (unsigned int i = 0; i < inShape1.size(); i++) {
if (i < fillDimNum) {
x2_shape_[i] = 1;
} else {
x2_shape_[i] = inShape0[j++];
}
x1_shape_[i] = inShape1[i];
dy_shape_[i] = outShape[i];
}
} else if (dx2->ElementsNum() < dx1->ElementsNum()) { // if (inShape0.size() > inShape1.size())
ndim_ = inShape0.size();
broadcasting_ = true;
ndim_ = inShape0.size();
int j = 0;
auto fillDimNum = inShape0.size() - inShape1.size();
for (unsigned int i = 0; i < inShape0.size(); i++) {
if (i < fillDimNum) {
x2_shape_[i] = 1;
} else {
x2_shape_[i] = inShape1[j++];
}
x1_shape_[i] = inShape0[i];
dy_shape_[i] = outShape[i];
}
} else {
broadcasting_ = false;
for (unsigned int i = 0; i < inShape0.size(); i++) {
x2_shape_[i] = inShape1[i];
x1_shape_[i] = inShape0[i];
dy_shape_[i] = outShape[i];
}
}
}
dx1->set_shape(x1->shape());
dx2->set_shape(x2->shape());
dx1->set_data_type(dy->data_type());
dx2->set_data_type(dy->data_type());
return RET_OK;
}
} // namespace lite
} // namespace mindspore
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_ARITHMETIC_GRAD_H_
#define MINDSPORE_LITE_SRC_OPS_ARITHMETIC_GRAD_H_
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class ArithmeticGrad : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ArithmeticGrad, PrimitiveC);
ArithmeticGrad() = default;
explicit ArithmeticGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
// explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {}
ArithmeticGrad() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
return RET_ERROR;
}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
bool Broadcasting() { return this->broadcasting_; }
int NDims() { return this->ndim_; }
std::vector<int> dyShape() { return this->dy_shape_; }
std::vector<int> x1Shape() { return this->x1_shape_; }
std::vector<int> x2Shape() { return this->x2_shape_; }
protected:
bool broadcasting_ = false;
int ndim_;
std::vector<int> dy_shape_;
std::vector<int> x1_shape_;
std::vector<int> x2_shape_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_OPS_ARITHMETIC_GRAD_H_
......@@ -48,6 +48,32 @@ std::vector<int> BiasGrad::GetAxis() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
int BiasGrad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
if (1 != inputs.size()) {
MS_LOG(ERROR) << "BiasGrad should have one input";
return RET_ERROR;
}
if (1 != outputs.size()) {
MS_LOG(ERROR) << "BiasGrad should have one output";
return RET_ERROR;
}
auto *in0 = inputs.front();
auto *out = outputs.front();
MS_ASSERT(in0 != nullptr);
MS_ASSERT(out != nullptr);
auto inshape = in0->shape();
int ndim = inshape.size();
for (int i = 0; i < ndim - 1; i++) {
inshape[i] = 1;
}
out->set_shape(inshape);
out->set_data_type(in0->data_type());
out->SetFormat(in0->GetFormat());
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_
#ifndef MINDSPORE_LITE_SRC_OPS_BIAS_GRAD_H_
#define MINDSPORE_LITE_SRC_OPS_BIAS_GRAD_H_
#include <vector>
#include <set>
......@@ -38,10 +38,11 @@ class BiasGrad : public PrimitiveC {
BiasGrad() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override;
#endif
std::vector<int> GetAxis() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_
#endif // MINDSPORE_LITE_SRC_OPS_BIAS_GRAD_H_
......@@ -14,33 +14,33 @@
* limitations under the License.
*/
#include "src/ops/bn_grad_input.h"
#include "src/ops/bn_grad.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float BNGradInput::GetEps() const { return this->primitive_->value.AsBNGradInput()->eps; }
int BNGradInput::GetChannels() const { return this->primitive_->value.AsBNGradInput()->channels; }
float BNGrad::GetEps() const { return this->primitive_->value.AsBNGrad()->eps; }
float BNGrad::GetMomentum() const { return this->primitive_->value.AsBNGrad()->momentum; }
void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->eps = eps; }
void BNGradInput::SetChannels(int channels) { this->primitive_->value.AsBNGradInput()->channels = channels; }
void BNGrad::SetEps(float eps) { this->primitive_->value.AsBNGrad()->eps = eps; }
void BNGrad::SetMomentum(float momentum) { this->primitive_->value.AsBNGrad()->momentum = momentum; }
#else
int BNGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
int BNGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_BNGradInput();
auto attr = primitive->value_as_BNGrad();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_BNGradInput return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateBNGradInput(*fbb, attr->eps(), attr->channels());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGradInput, val_offset.o);
auto val_offset = schema::CreateBNGrad(*fbb, attr->eps(), attr->momentum());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGrad, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); }
int BNGradInput::GetChannels() const { return this->primitive_->value_as_BNGradInput()->channels(); }
float BNGrad::GetEps() const { return this->primitive_->value_as_BNGrad()->eps(); }
float BNGrad::GetMomentum() const { return this->primitive_->value_as_BNGrad()->momentum(); }
#endif
} // namespace lite
......
......@@ -25,21 +25,20 @@
namespace mindspore {
namespace lite {
class BNGradInput : public PrimitiveC {
class BNGrad : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BNGradInput, PrimitiveC);
BNGradInput() = default;
explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
MS_DECLARE_PARENT(BNGrad, PrimitiveC);
BNGrad() = default;
explicit BNGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetEps(float eps);
void SetChannels(int channels);
void SetMomentum(float momentum);
#else
BNGradInput() = default;
BNGrad() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
float GetEps() const;
int GetChannels() const;
float GetMomentum() const;
};
} // namespace lite
} // namespace mindspore
......
......@@ -105,5 +105,47 @@ int Conv2DGradFilter::GetActivationType() const {
}
#endif
int Conv2DGradFilter::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
if (3 != inputs.size()) {
MS_LOG(ERROR) << "Conv2d Grad Filter should have 3 inputs";
return RET_ERROR;
}
if (1 != outputs.size()) {
MS_LOG(ERROR) << "Conv2d Grad Filter should have one output";
return RET_ERROR;
}
auto *in0 = inputs.at(0);
auto *in = inputs.at(2);
MS_ASSERT(out != nullptr);
std::vector<int> output_shape;
int *out_shape = reinterpret_cast<int *>(in->Data());
int new_size = in->ElementsNum();
if (in0->GetFormat() == in->GetFormat()) {
for (int i = 0; i < new_size; i++) output_shape.push_back(out_shape[i]);
} else {
if ((in0->GetFormat() == schema::Format_NHWC) && (in->GetFormat() == schema::Format_NCHW)) {
output_shape.push_back(out_shape[0]);
output_shape.push_back(out_shape[2]);
output_shape.push_back(out_shape[3]);
output_shape.push_back(out_shape[1]);
} else {
MS_LOG(ERROR) << "Shape covnert is not supported";
return RET_ERROR;
}
}
auto *out = outputs.at(0);
MS_ASSERT(out != nullptr);
out->set_shape(output_shape);
out->set_data_type(in0->data_type());
out->SetFormat(in0->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_
#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_
#ifndef MINDSPORE_LITE_SRC_OPS_CONV2D_GRAD_FILTER_H_
#define MINDSPORE_LITE_SRC_OPS_CONV2D_GRAD_FILTER_H_
#include <vector>
#include <set>
......@@ -53,6 +53,7 @@ class Conv2DGradFilter : public PrimitiveC {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetFormat() const;
int GetGroup() const;
int GetChannelIn() const;
......@@ -74,4 +75,4 @@ class Conv2DGradFilter : public PrimitiveC {
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_
#endif // MINDSPORE_LITE_SRC_OPS_CONV2D_GRAD_FILTER_H_
......@@ -103,5 +103,46 @@ int Conv2DGradInput::GetActivationType() const {
}
#endif
int Conv2DGradInput::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
if (3 != inputs.size()) {
MS_LOG(ERROR) << "Conv2d Grad Input should have 3 inputs";
return RET_ERROR;
}
if (1 != outputs.size()) {
MS_LOG(ERROR) << "Conv2d Grad input should have one output";
return RET_ERROR;
}
auto *in0 = inputs.at(0);
auto *in = inputs.at(2);
MS_ASSERT(out != nullptr);
std::vector<int> output_shape;
int *out_shape = reinterpret_cast<int *>(in->Data());
int new_size = in->ElementsNum();
if (in0->GetFormat() == in->GetFormat()) {
for (int i = 0; i < new_size; i++) output_shape.push_back(out_shape[i]);
} else {
if ((in0->GetFormat() == schema::Format_NHWC) && (in->GetFormat() == schema::Format_NCHW)) {
output_shape.push_back(out_shape[0]);
output_shape.push_back(out_shape[2]);
output_shape.push_back(out_shape[3]);
output_shape.push_back(out_shape[1]);
} else {
MS_LOG(ERROR) << "Shape covnert is not supported";
return RET_ERROR;
}
}
auto *out = outputs.at(0);
MS_ASSERT(out != nullptr);
out->set_shape(output_shape);
out->set_data_type(in0->data_type());
out->SetFormat(in0->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_
#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_
#ifndef MINDSPORE_LITE_SRC_OPS_CONV2D_GRAD_INPUT_H_
#define MINDSPORE_LITE_SRC_OPS_CONV2D_GRAD_INPUT_H_
#include <vector>
#include <set>
......@@ -53,6 +53,7 @@ class Conv2DGradInput : public PrimitiveC {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetFormat() const;
int GetGroup() const;
int GetChannelIn() const;
......@@ -74,4 +75,4 @@ class Conv2DGradInput : public PrimitiveC {
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_
#endif // MINDSPORE_LITE_SRC_OPS_CONV2D_GRAD_INPUT_H_
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_
#define LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_
#ifndef MINDSPORE_LITE_SRC_OPS_DEDEPTHWISE_CONV2D_H_
#define MINDSPORE_LITE_SRC_OPS_DEDEPTHWISE_CONV2D_H_
#include <vector>
#include <set>
......@@ -84,4 +84,4 @@ class DeDepthwiseConv2D : public PrimitiveC {
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_
#endif // MINDSPORE_LITE_SRC_OPS_DEDEPTHWISE_CONV2D_H_
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_
#define LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_
#ifndef MINDSPORE_LITE_SRC_OPS_DEPTHWISE_CONV2D_H_
#define MINDSPORE_LITE_SRC_OPS_DEPTHWISE_CONV2D_H_
#include <vector>
#include <set>
......@@ -94,4 +94,4 @@ class DepthwiseConv2D : public PrimitiveC {
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_
#endif // MINDSPORE_LITE_SRC_OPS_DEPTHWISE_CONV2D_H_
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_
#define LITE_MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_
#ifndef MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_
#define MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_
#include <vector>
#include "src/ops/primitive_c.h"
......@@ -37,4 +37,4 @@ class MakeTuple : public PrimitiveC {
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_
#endif // MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_
......@@ -86,5 +86,52 @@ int PoolingGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuf
return RET_OK;
}
#endif
int PoolingGrad::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.at(0);
MS_ASSERT(input != nullptr);
int input_h = input->shape().at(1);
int input_w = input->shape().at(2);
auto window_h = GetWindowH();
auto window_w = GetWindowW();
if (GetGlobal()) {
window_h = input_h;
window_w = input_w;
}
pad_l_ = GetPadLeft();
pad_u_ = GetPadUp();
pad_d_ = GetPadDown();
pad_r_ = GetPadRight();
if (GetPadMode() == schema::PadMode_SAME) {
int output_w = std::ceil(static_cast<float>(input_w) / static_cast<float>(GetStrideW()));
int output_h = std::ceil(static_cast<float>(input_h) / static_cast<float>(GetStrideH()));
auto pad_h_all = ((output_h - 1) * GetStrideH() + (window_h - 1) + 1 - input_h);
auto pad_w_all = ((output_w - 1) * GetStrideW() + (window_w - 1) + 1 - input_w);
if (pad_h_all < 0) {
pad_u_ = pad_d_ = 0;
} else {
pad_u_ = pad_h_all / 2;
pad_d_ = pad_h_all - pad_u_;
}
if (pad_w_all < 0) {
pad_l_ = pad_r_ = 0;
} else {
pad_l_ = pad_w_all / 2;
pad_r_ = pad_w_all - pad_l_;
}
}
auto grad_output = outputs_.at(0);
// todo: fmk type
auto output_shape = input->shape();
grad_output->set_shape(output_shape);
grad_output->set_data_type(input->data_type());
// todo: temp fix
grad_output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_POOLING_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_POOLING_GRAD_H_
#ifndef MINDSPORE_LITE_SRC_OPS_POOLING_GRAD_H_
#define MINDSPORE_LITE_SRC_OPS_POOLING_GRAD_H_
#include <vector>
#include <set>
......@@ -49,6 +49,7 @@ class PoolingGrad : public PrimitiveC {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetFormat() const;
int GetPoolingMode() const;
bool GetGlobal() const;
......@@ -62,8 +63,14 @@ class PoolingGrad : public PrimitiveC {
int GetPadLeft() const;
int GetPadRight() const;
int GetRoundMode() const;
protected:
int pad_u_ = 0;
int pad_d_ = 0;
int pad_l_ = 0;
int pad_r_ = 0;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_POOLING_GRAD_H_
#endif // MINDSPORE_LITE_SRC_OPS_POOLING_GRAD_H_
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_POWER_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_POWER_GRAD_H_
#ifndef MINDSPORE_LITE_SRC_OPS_POWER_GRAD_H_
#define MINDSPORE_LITE_SRC_OPS_POWER_GRAD_H_
#include <vector>
#include <set>
......@@ -46,4 +46,4 @@ class PowerGrad : public PrimitiveC {
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_POWER_GRAD_H_
#endif // MINDSPORE_LITE_SRC_OPS_POWER_GRAD_H_
......@@ -125,6 +125,21 @@
#ifdef PRIMITIVE_WRITEABLE
#include "tools/converter/quantizer/quantize_util.h"
#endif
#ifdef SUPPORT_TRAIN
#include "src/ops/activation_grad.h"
#include "src/ops/apply_momentum.h"
#include "src/ops/bias_grad.h"
#include "src/ops/pooling_grad.h"
#include "src/ops/conv2d_grad_filter.h"
#include "src/ops/conv2d_grad_input.h"
#include "src/ops/power_grad.h"
#include "src/ops/softmax_cross_entropy.h"
#include "src/ops/bn_grad.h"
#include "src/ops/arithmetic_grad.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
......@@ -353,6 +368,22 @@ std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &pri
return NewPrimitiveC<TupleGetItem>(prim, inputs, quantType);
} else if (op_type == "Softmax") {
return NewPrimitiveC<SoftMax>(prim, inputs, quantType);
#ifdef SUPPORT_TRAIN0
} else if ((op_type == "ReluGrad" || op_type == "Relu6Grad" || op_type == "SigmoidGrad")) {
return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType);
} else if ((op_type == "MaxPoolGrad") || (op_type == "MeanPoolGrad")) {
return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType);
} else if (op_type == "Conv2DBackpropFilter") {
return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType);
} else if (op_type == "Conv2DBackpropInput") {
return NewPrimitiveC<Conv2DGradInput>(prim, inputs, quantType);
} else if (op_type == "BiasAddGrad") {
return NewPrimitiveC<BiasGrad>(prim, inputs, quantType);
} else if (op_type == "ApplyMomentum") {
return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType);
} else if (op_type == "BatchNormGrad") {
return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
#endif
} else {
MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type;
return nullptr;
......@@ -565,6 +596,32 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT
return new SparseToDense(primitive);
case schema::PrimitiveType_DetectionPostProcess:
return new DetectionPostProcess(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return new ActivationGrad(primitive);
case schema::PrimitiveType_PoolingGrad:
return new PoolingGrad(primitive);
case schema::PrimitiveType_Conv2DGradFilter:
return new Conv2DGradFilter(primitive);
case schema::PrimitiveType_Conv2DGradInput:
return new Conv2DGradInput(primitive);
case schema::PrimitiveType_BiasGrad:
return new BiasGrad(primitive);
case schema::PrimitiveType_ApplyMomentum:
return new ApplyMomentum(primitive);
case schema::PrimitiveType_BNGrad:
return new BNGrad(primitive);
case schema::PrimitiveType_AddGrad:
return new ArithmeticGrad(primitive);
case schema::PrimitiveType_SubGrad:
return new ArithmeticGrad(primitive);
case schema::PrimitiveType_MulGrad:
return new ArithmeticGrad(primitive);
case schema::PrimitiveType_DivGrad:
return new ArithmeticGrad(primitive);
#endif
default:
MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitiveT : "
<< schema::EnumNamePrimitiveType(op_type);
......@@ -779,6 +836,31 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(const schema::Primitive *primi
return NewPrimitiveC<SparseToDense>(primitive);
case schema::PrimitiveType_DetectionPostProcess:
return NewPrimitiveC<DetectionPostProcess>(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return NewPrimitiveC<ActivationGrad>(primitive);
case schema::PrimitiveType_PoolingGrad:
return NewPrimitiveC<PoolingGrad>(primitive);
case schema::PrimitiveType_Conv2DGradFilter:
return NewPrimitiveC<Conv2DGradFilter>(primitive);
case schema::PrimitiveType_Conv2DGradInput:
return NewPrimitiveC<Conv2DGradInput>(primitive);
case schema::PrimitiveType_BiasGrad:
return NewPrimitiveC<BiasGrad>(primitive);
case schema::PrimitiveType_ApplyMomentum:
return NewPrimitiveC<ApplyMomentum>(primitive);
case schema::PrimitiveType_BNGrad:
return NewPrimitiveC<BNGrad>(primitive);
case schema::PrimitiveType_AddGrad:
return NewPrimitiveC<ArithmeticGrad>(primitive);
case schema::PrimitiveType_SubGrad:
return NewPrimitiveC<ArithmeticGrad>(primitive);
case schema::PrimitiveType_MulGrad:
return NewPrimitiveC<ArithmeticGrad>(primitive);
case schema::PrimitiveType_DivGrad:
return NewPrimitiveC<ArithmeticGrad>(primitive);
#endif
default:
MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitive : "
<< schema::EnumNamePrimitiveType(op_type);
......
......@@ -115,7 +115,7 @@ constexpr size_t kInputSize = 1;
constexpr size_t kOutputSize = 1;
} // namespace
int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
if (inputs_.size() != kInputSize || outputs_.size() != kOutputSize) {
if (inputs_.size() < kInputSize || outputs_.size() != kOutputSize) {
return RET_ERROR;
}
auto input = inputs_.front();
......
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_RESHAPE_H_
#define LITE_MINDSPORE_LITE_C_OPS_RESHAPE_H_
#ifndef MINDSPORE_LITE_SRC_OPS_RESHAPE_H_
#define MINDSPORE_LITE_SRC_OPS_RESHAPE_H_
#include <vector>
#include <set>
......@@ -50,4 +50,4 @@ class Reshape : public PrimitiveC {
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_RESHAPE_H_
#endif // MINDSPORE_LITE_SRC_OPS_RESHAPE_H_
......@@ -51,5 +51,31 @@ int SoftmaxCrossEntropy::UnPackToFlatBuilder(const schema::Primitive *primitive,
return RET_OK;
}
#endif
int SoftmaxCrossEntropy::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
if (1 > outputs.size()) {
MS_LOG(ERROR) << "SoftmaxCrossEntropy should have at least one output";
return RET_ERROR;
}
auto *in0 = inputs.front();
MS_ASSERT(in0 != nullptr);
auto *out = outputs.front();
MS_ASSERT(out != nullptr);
std::vector<int> outshape;
outshape.push_back(1);
out->set_shape(outshape);
out->set_data_type(in0->data_type());
if (1 < outputs.size()) {
auto *grads = outputs.at(1);
MS_ASSERT(grads != nullptr);
grads->set_shape(in0->shape());
grads->set_data_type(in0->data_type());
grads->SetFormat(in0->GetFormat());
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_SOFTMAX_CROSS_ENTROPY_H_
#define LITE_MINDSPORE_LITE_C_OPS_SOFTMAX_CROSS_ENTROPY_H_
#ifndef MINDSPORE_LITE_SRC_OPS_SOFTMAX_CROSS_ENTROPY_H_
#define MINDSPORE_LITE_SRC_OPS_SOFTMAX_CROSS_ENTROPY_H_
#include <vector>
#include <set>
......@@ -39,9 +39,11 @@ class SoftmaxCrossEntropy : public PrimitiveC {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetAxis() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_SOFTMAX_CROSS_ENTROPY_H_
#endif // MINDSPORE_LITE_SRC_OPS_SOFTMAX_CROSS_ENTROPY_H_
......@@ -1678,6 +1678,13 @@ PopulateParameterFunc PopulateParameterRegistry::GetParameterFunc(int type) {
return populate_parameter_funcs_[schema::PrimitiveType(type)];
}
int PopulateParameterRegistry::AddPopulateParameterFunc(const schema::PrimitiveType &type, PopulateParameterFunc func) {
if ((type < schema::PrimitiveType_MIN)|| (type > schema::PrimitiveType_MAX))
return -1;
populate_parameter_funcs_[type] = func;
return 0;
}
OpParameter *PopulateParameter(const mindspore::lite::PrimitiveC *primitive) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
......
......@@ -30,12 +30,16 @@ class PopulateParameterRegistry {
~PopulateParameterRegistry() = default;
static PopulateParameterRegistry *GetInstance();
int AddPopulateParameterFunc(const schema::PrimitiveType &type, PopulateParameterFunc func);
PopulateParameterFunc GetParameterFunc(int type);
protected:
PopulateParameterFunc populate_parameter_funcs_[schema::PrimitiveType_MAX + 1];
};
OpParameter *PopulateActivationParameter(const lite::PrimitiveC *primitive);
OpParameter *PopulateArithmetic(const lite::PrimitiveC *primitive);
OpParameter *PopulateParameter(const mindspore::lite::PrimitiveC *primitive);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_
......@@ -37,8 +37,8 @@ constexpr size_t kOutputNum = 1;
} // namespace
int ReduceBaseCPUKernel::CheckInputsOutputs() {
if (in_tensors_.size() != kInputNum) {
MS_LOG(ERROR) << "Reduce inputs size should be " << kInputNum << " but got " << in_tensors_.size();
if (in_tensors_.size() < kInputNum) {
MS_LOG(ERROR) << "Reduce inputs size should be at least " << kInputNum << " but got " << in_tensors_.size();
return RET_ERROR;
}
if (out_tensors_.size() != kOutputNum) {
......@@ -99,7 +99,15 @@ int ReduceBaseCPUKernel::Init() {
if (reduce_param == nullptr) {
return RET_NULL_PTR;
}
num_axes_ = reduce_param->num_axes_;
if (in_tensors_.size() > 1) {
auto axes_ptr = in_tensors_.at(1);
num_axes_ = axes_ptr->ElementsNum();
memcpy(axes_, axes_ptr->Data(), axes_ptr->Size());
} else {
num_axes_ = reduce_param->num_axes_;
memcpy(axes_, reduce_param->axes_, sizeof(reduce_param->axes_));
}
mode_ = reduce_param->mode_;
memcpy(axes_, reduce_param->axes_, sizeof(reduce_param->axes_));
reduce_to_end_ = reduce_param->reduce_to_end_;
......
......@@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32_grad/activation_grad.h"
#include "nnacl/fp32_grad/activation_grad.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
......@@ -24,41 +25,38 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::ActivationGradType_HSWISH;
using mindspore::schema::ActivationGradType_LEAKY_RELU;
using mindspore::schema::ActivationGradType_RELU;
using mindspore::schema::ActivationGradType_RELU6;
using mindspore::schema::ActivationType_HSWISH;
using mindspore::schema::ActivationType_LEAKY_RELU;
using mindspore::schema::ActivationType_RELU;
using mindspore::schema::ActivationType_RELU6;
using mindspore::schema::PrimitiveType_ActivationGrad;
namespace mindspore::kernel {
int ActivationGradCPUKernel::Init() {
outputs_[0]->set_shape(inputs_[0]->shape());
return RET_OK;
}
int ActivationGradCPUKernel::Init() { return RET_OK; }
int ActivationGradCPUKernel::ReSize() { return RET_OK; }
int ActivationGradCPUKernel::DoActivation(int task_id) {
auto yt_addr = reinterpret_cast<float *>(inputs_.at(0)->Data());
auto input_addr = reinterpret_cast<float *>(inputs_.at(1)->Data());
auto output_addr = reinterpret_cast<float *>(outputs_.at(0)->Data());
auto length = inputs_.at(0)->ElementsNum();
auto yt_addr = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
auto input_addr = reinterpret_cast<float *>(in_tensors_.at(1)->Data());
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
int length = in_tensors_.at(0)->ElementsNum();
auto error_code = RET_OK;
if (type_ == schema::ActivationGradType_RELU) {
if (param_act_grad_->type_ == schema::ActivationType_RELU) {
error_code = ReluGrad(yt_addr, input_addr, length, output_addr);
} else if (type_ == schema::ActivationGradType_RELU6) {
} else if (param_act_grad_->type_ == schema::ActivationType_RELU6) {
error_code = Relu6Grad(yt_addr, input_addr, length, output_addr);
} else if (type_ == schema::ActivationGradType_LEAKY_RELU) {
error_code = LReluGrad(yt_addr, input_addr, length, output_addr, alpha_);
} else if (type_ == schema::ActivationGradType_SIGMOID) {
} else if (param_act_grad_->type_ == schema::ActivationType_LEAKY_RELU) {
error_code = LReluGrad(yt_addr, input_addr, length, output_addr, param_act_grad_->alpha_);
} else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) {
error_code = SigmoidGrad(yt_addr, input_addr, length, output_addr);
} else if (type_ == schema::ActivationGradType_TANH) {
} else if (param_act_grad_->type_ == schema::ActivationType_TANH) {
error_code = TanhGrad(yt_addr, input_addr, length, output_addr);
} else if (type_ == schema::ActivationGradType_HSWISH) {
} else if (param_act_grad_->type_ == schema::ActivationType_HSWISH) {
error_code = HSwishGrad(yt_addr, input_addr, length, output_addr);
} else if (type_ == schema::ActivationGradType_HSIGMOID) {
} else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) {
error_code = HSigmoidGrad(yt_addr, input_addr, length, output_addr);
} else {
MS_LOG(ERROR) << "Activation type error";
......@@ -81,6 +79,12 @@ int ActivationGradRun(void *cdata, int task_id) {
}
int ActivationGradCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed.";
return ret;
}
int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, ActivationGradRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Activation function error error_code[" << error_code << "]";
......
......@@ -20,8 +20,7 @@
#include <vector>
#include "src/lite_kernel.h"
#include "ir/anf.h"
#include "nnacl/activation_grad.h"
#include "nnacl/fp32/activation.h"
namespace mindspore::kernel {
class ActivationGradCPUKernel : public LiteKernel {
......@@ -30,9 +29,7 @@ class ActivationGradCPUKernel : public LiteKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(param, inputs, outputs, ctx, primitive) {
ActivationGradParameter *param_act_grad = reinterpret_cast<ActivationGradParameter *>(param);
type_ = param_act_grad->type_;
alpha_ = param_act_grad->alpha_;
param_act_grad_ = reinterpret_cast<ActivationParameter *>(param);
}
~ActivationGradCPUKernel() override = default;
......@@ -43,9 +40,9 @@ class ActivationGradCPUKernel : public LiteKernel {
private:
int thread_count_;
int type_;
float alpha_;
ActivationParameter *param_act_grad_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_
......@@ -15,63 +15,81 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32_grad/apply_momentum.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/fp32_grad/opt_momentum.h"
#include "include/errorcode.h"
#include "src/runtime/kernel/arm/fp32/nchw2nhwc.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_OptMomentum;
using mindspore::schema::PrimitiveType_ApplyMomentum;
namespace mindspore::kernel {
int OptMomentumCPUKernel::ReSize() { return 0; }
int ApplyMomentumCPUKernel::ReSize() { return RET_OK; }
int OptMomentumCPUKernel::Run() {
int ApplyMomentumCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
if (inputs_.size() != 5 || !outputs_.empty()) {
MS_LOG(ERROR) << "OptMomentumCPUKernel error input output size!";
return RET_ERROR;
}
if (inputs_[0]->ElementsNum() != inputs_[1]->ElementsNum() ||
inputs_[0]->ElementsNum() != inputs_[3]->ElementsNum()) {
MS_LOG(ERROR) << "error input data size!";
return RET_ERROR;
auto weight = reinterpret_cast<float *>(in_tensors_[0]->Data());
auto accumulate = reinterpret_cast<float *>(in_tensors_[1]->Data());
float learning_rate = reinterpret_cast<float *>(in_tensors_[2]->Data())[0];
auto gradient = reinterpret_cast<float *>(in_tensors_[3]->Data());
float moment = reinterpret_cast<float *>(in_tensors_[4]->Data())[0];
size_t elem_num = in_tensors_[0]->ElementsNum();
// align format
if (in_tensors_[3]->shape().size() == 4 &&
in_tensors_[3]->GetFormat() == schema::Format_NCHW &&
in_tensors_[0]->GetFormat() == schema::Format_KHWC) {
PackNCHWToNHWCFp32(gradient, workspace, in_tensors_[0]->Batch(), in_tensors_[0]->Height() * in_tensors_[0]->Width(),
in_tensors_[0]->Channel());
} else {
memcpy(workspace, gradient, in_tensors_[3]->ElementsNum() * sizeof(float));
}
auto weight = reinterpret_cast<float *>(inputs_[0]->Data());
auto accumulate = reinterpret_cast<float *>(inputs_[1]->Data());
float learning_rate = reinterpret_cast<float *>(inputs_[2]->Data())[0];
auto gradient = reinterpret_cast<float *>(inputs_[3]->Data());
float moment = reinterpret_cast<float *>(inputs_[4]->Data())[0];
size_t elem_num = inputs_[0]->ElementsNum();
for (size_t i = 0; i < elem_num; ++i) {
accumulate[i] = accumulate[i] * moment + gradient[i];
accumulate[i] = accumulate[i] * moment + workspace[i]; // * (1.0 - moment);
weight[i] -= accumulate[i] * learning_rate;
}
return RET_OK;
}
int OptMomentumCPUKernel::Init() { return 0; }
int ApplyMomentumCPUKernel::Init() {
// Only for test with uninitialized Data
size_t elem_num = in_tensors_[0]->ElementsNum();
auto accumulate = reinterpret_cast<float *>(in_tensors_[1]->Data());
for (int i =0; i < elem_num; i++) accumulate[i] = 0.0;
kernel::LiteKernel *CpuOptMomentumFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(desc.type == schema::PrimitiveType_OptMomentum);
auto *kernel = new (std::nothrow) OptMomentumCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new OptMomentumCPUKernel fail!";
workspace = new float[elem_num];
return 0;
}
#if 0
OpParameter *PopulateApplyMomentumParameter(const lite::Primitive *primitive) {
OpParameter *param = new (std::nothrow) OpParameter();
if (param == nullptr) {
MS_LOG(ERROR) << "new Param for OptMomentum failed.";
return nullptr;
}
param->type_ = primitive->Type();
return param;
}
#endif
kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc, const lite::PrimitiveC *primitive) {
MS_ASSERT(desc.type == schema::PrimitiveType_ApplyMomentum);
auto *kernel = new (std::nothrow) ApplyMomentumCPUKernel(opParameter, inputs, outputs, ctx, primitive);
MS_ASSERT(kernel != nullptr);
auto ret = kernel->Init();
if (0 != ret) {
......@@ -83,5 +101,5 @@ kernel::LiteKernel *CpuOptMomentumFp32KernelCreator(const std::vector<lite::tens
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_OptMomentum, CpuOptMomentumFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ApplyMomentum, CpuApplyMomentumFp32KernelCreator)
} // namespace mindspore::kernel
......@@ -14,28 +14,32 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_OPT_MOMENTUM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_OPT_MOMENTUM_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_APPLY_MOMENTUM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_APPLY_MOMENTUM_H_
#include <vector>
#include "src/lite_kernel.h"
#include "ir/anf.h"
namespace mindspore::kernel {
class OptMomentumCPUKernel : public LiteKernel {
class ApplyMomentumCPUKernel : public LiteKernel {
public:
explicit OptMomentumCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
explicit ApplyMomentumCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~OptMomentumCPUKernel() override {}
~ApplyMomentumCPUKernel() override {delete [] workspace;}
int Init() override;
int ReSize() override;
int Run() override;
private:
float *workspace;
};
// OpParameter *PopulateApplyMomentumParameter(const lite::Primitive *primitive);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_OPT_MOMENTUM_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_APPLY_MOMENTUM_H_
......@@ -14,11 +14,11 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "nnacl/fp32_grad/reduce_grad.h"
#include "nnacl/fp32_grad/arithmetic_grad.h"
#include "src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
......@@ -33,108 +33,41 @@ constexpr int kArithGradOpOutputNum = 2;
} // namespace
int ArithmeticGradCPUKernel::Init() {
auto ret = InferShape();
return ret;
}
int ArithmeticGradCPUKernel::InferShape() {
if (inputs_.size() != kArithGradOpInputNum) {
MS_LOG(ERROR) << "The number of input must be " << kArithGradOpInputNum;
return RET_ERROR;
}
if (outputs_.size() != kArithGradOpOutputNum) {
MS_LOG(ERROR) << "The number of output must be " << kArithGradOpOutputNum;
return RET_ERROR;
}
auto dy = inputs_[0];
auto x1 = inputs_[1];
auto x2 = inputs_[2];
auto dx1 = outputs_[0];
auto dx2 = outputs_[1];
auto dx1 = out_tensors_[0];
auto dx2 = out_tensors_[1];
MS_ASSERT(dy != nullptr);
MS_ASSERT(x1 != nullptr);
MS_ASSERT(x2 != nullptr);
MS_ASSERT(dx1 != nullptr);
MS_ASSERT(dx2 != nullptr);
auto inShape0 = x1->shape();
auto inShape1 = x2->shape();
auto outShape = dy->shape();
if ((type() == PrimitiveType_AddGrad) || (type() == PrimitiveType_SubGrad)) {
arithmeticParameter_->ndim_ = outShape.size();
auto fillDimNum0 = outShape.size() - inShape0.size();
auto fillDimNum1 = outShape.size() - inShape1.size();
int j0 = 0;
int j1 = 0;
for (unsigned int i = 0; i < outShape.size(); i++) {
arithmeticParameter_->in_shape0_[i] = (i < fillDimNum0) ? 1 : inShape0[j0++];
arithmeticParameter_->in_shape1_[i] = (i < fillDimNum1) ? 1 : inShape1[j1++];
arithmeticParameter_->out_shape_[i] = outShape[i];
}
} else {
if ((Type() == PrimitiveType_MulGrad) || (Type() == PrimitiveType_DivGrad)) {
// if (inShape0.size() < inShape1.size())
if (dx1->ElementsNum() < dx2->ElementsNum()) {
arithmeticParameter_->ndim_ = inShape1.size();
if (type() == PrimitiveType_MulGrad)
if (Type() == PrimitiveType_MulGrad)
arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul2L;
else if (type() == PrimitiveType_DivGrad)
else if (Type() == PrimitiveType_DivGrad)
arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradDiv2L;
auto fillDimNum = inShape1.size() - inShape0.size(); // This will not work for batch!
int j = 0;
for (unsigned int i = 0; i < inShape1.size(); i++) {
if (i < fillDimNum) {
arithmeticParameter_->in_shape1_[i] = 1;
} else {
arithmeticParameter_->in_shape1_[i] = inShape0[j++];
}
arithmeticParameter_->in_shape0_[i] = inShape1[i];
arithmeticParameter_->out_shape_[i] = outShape[i];
}
} else if (dx2->ElementsNum() < dx1->ElementsNum()) { // if (inShape0.size() > inShape1.size())
arithmeticParameter_->ndim_ = inShape0.size();
if (type() == PrimitiveType_MulGrad)
if (Type() == PrimitiveType_MulGrad)
arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul1L;
else if (type() == PrimitiveType_DivGrad)
else if (Type() == PrimitiveType_DivGrad)
arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradDiv1L;
arithmeticParameter_->broadcasting_ = true;
arithmeticParameter_->ndim_ = inShape0.size();
int j = 0;
auto fillDimNum = inShape0.size() - inShape1.size();
for (unsigned int i = 0; i < inShape0.size(); i++) {
if (i < fillDimNum) {
arithmeticParameter_->in_shape1_[i] = 1;
} else {
arithmeticParameter_->in_shape1_[i] = inShape1[j++];
}
arithmeticParameter_->in_shape0_[i] = inShape0[i];
arithmeticParameter_->out_shape_[i] = outShape[i];
}
} else {
arithmeticParameter_->broadcasting_ = false;
for (unsigned int i = 0; i < inShape0.size(); i++) {
arithmeticParameter_->in_shape1_[i] = inShape1[i];
arithmeticParameter_->in_shape0_[i] = inShape0[i];
arithmeticParameter_->out_shape_[i] = outShape[i];
}
}
tile_data0 = new (std::nothrow) float[inputs_.at(0)->ElementsNum()];
tile_data0 = new (std::nothrow) float[in_tensors_.at(0)->ElementsNum()];
if (tile_data0 == nullptr) {
MS_LOG(ERROR) << "new data0 fail!";
return RET_ERROR;
}
tile_data1 = new (std::nothrow) float[inputs_.at(0)->ElementsNum()];
tile_data1 = new (std::nothrow) float[in_tensors_.at(0)->ElementsNum()];
if (tile_data1 == nullptr) {
MS_LOG(ERROR) << "new data1 fail!";
delete tile_data0;
return RET_ERROR;
}
if (type() == PrimitiveType_DivGrad) {
tile_data2 = new (std::nothrow) float[inputs_.at(0)->ElementsNum()];
if (Type() == PrimitiveType_DivGrad) {
tile_data2 = new (std::nothrow) float[in_tensors_.at(0)->ElementsNum()];
if (tile_data2 == nullptr) {
MS_LOG(ERROR) << "new data2 fail!";
delete tile_data0;
......@@ -144,10 +77,6 @@ int ArithmeticGradCPUKernel::InferShape() {
}
}
dx1->set_shape(x1->shape());
dx2->set_shape(x2->shape());
dx1->set_data_type(dy->data_type());
dx2->set_data_type(dy->data_type());
return RET_OK;
}
......@@ -187,16 +116,16 @@ void ArithmeticGradCPUKernel::ArithmeticGradSub(float *dy, int dy_size, float *d
void ArithmeticGradCPUKernel::ArithmeticGradMul(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2,
int dx2_size) {
auto x1_data = reinterpret_cast<float *>(inputs_[1]->Data());
auto x2_data = reinterpret_cast<float *>(inputs_[2]->Data());
auto x1_data = reinterpret_cast<float *>(in_tensors_[1]->Data());
auto x2_data = reinterpret_cast<float *>(in_tensors_[2]->Data());
ElementMul(dy, x1_data, dx2, dy_size);
ElementMul(dy, x2_data, dx1, dy_size);
}
void ArithmeticGradCPUKernel::ArithmeticGradMul1L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2,
int dx2_size) {
auto x1_data = reinterpret_cast<float *>(inputs_[1]->Data());
auto x2_data = reinterpret_cast<float *>(inputs_[2]->Data());
auto x1_data = reinterpret_cast<float *>(in_tensors_[1]->Data());
auto x2_data = reinterpret_cast<float *>(in_tensors_[2]->Data());
ElementMul(dy, x1_data, tile_data0, dy_size);
ReduceSumByAxes(tile_data0, arithmeticParameter_->in_shape0_, dx2, arithmeticParameter_->in_shape1_,
arithmeticParameter_->ndim_);
......@@ -206,8 +135,8 @@ void ArithmeticGradCPUKernel::ArithmeticGradMul1L(float *dy, int dy_size, float
void ArithmeticGradCPUKernel::ArithmeticGradMul2L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2,
int dx2_size) {
auto x1_data = reinterpret_cast<float *>(inputs_[1]->Data());
auto x2_data = reinterpret_cast<float *>(inputs_[2]->Data());
auto x1_data = reinterpret_cast<float *>(in_tensors_[1]->Data());
auto x2_data = reinterpret_cast<float *>(in_tensors_[2]->Data());
ElementMul(dy, x2_data, tile_data0, dy_size);
ReduceSumByAxes(tile_data0, arithmeticParameter_->in_shape0_, dx1, arithmeticParameter_->in_shape1_,
arithmeticParameter_->ndim_);
......@@ -217,16 +146,16 @@ void ArithmeticGradCPUKernel::ArithmeticGradMul2L(float *dy, int dy_size, float
void ArithmeticGradCPUKernel::ArithmeticGradDiv(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2,
int dx2_size) {
auto x1 = reinterpret_cast<float *>(inputs_[1]->Data());
auto x2 = reinterpret_cast<float *>(inputs_[2]->Data());
auto x1 = reinterpret_cast<float *>(in_tensors_[1]->Data());
auto x2 = reinterpret_cast<float *>(in_tensors_[2]->Data());
ElementDiv(dy, x2, dx1, dy_size);
ElementMulAndDivNegSquare(dy, x1, x2, dx2, dy_size);
}
void ArithmeticGradCPUKernel::ArithmeticGradDiv1L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2,
int dx2_size) {
auto x1_data = reinterpret_cast<float *>(inputs_[1]->Data());
auto x2_data = reinterpret_cast<float *>(inputs_[2]->Data());
auto x1_data = reinterpret_cast<float *>(in_tensors_[1]->Data());
auto x2_data = reinterpret_cast<float *>(in_tensors_[2]->Data());
ElementMul(x2_data, x2_data, dx2, dx2_size);
ElementMul(x1_data, dy, dx1, dy_size); // use dx1 buffer
......@@ -243,8 +172,8 @@ void ArithmeticGradCPUKernel::ArithmeticGradDiv1L(float *dy, int dy_size, float
void ArithmeticGradCPUKernel::ArithmeticGradDiv2L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2,
int dx2_size) {
auto x1_data = reinterpret_cast<float *>(inputs_[1]->Data());
auto x2_data = reinterpret_cast<float *>(inputs_[2]->Data());
auto x1_data = reinterpret_cast<float *>(in_tensors_[1]->Data());
auto x2_data = reinterpret_cast<float *>(in_tensors_[2]->Data());
// dx1 = dy/x2
ElementDiv(dy, x2_data, tile_data0, dy_size); // first multiply into temp
......@@ -259,13 +188,13 @@ void ArithmeticGradCPUKernel::ArithmeticGradDiv2L(float *dy, int dy_size, float
int ArithmeticGradCPUKernel::ReSize() { return RET_OK; }
int ArithmeticGradCPUKernel::Run() {
auto dy = reinterpret_cast<float *>(inputs_[0]->Data());
auto dx1 = reinterpret_cast<float *>(outputs_[0]->Data());
auto dx2 = reinterpret_cast<float *>(outputs_[1]->Data());
auto dy = reinterpret_cast<float *>(in_tensors_[0]->Data());
auto dx1 = reinterpret_cast<float *>(out_tensors_[0]->Data());
auto dx2 = reinterpret_cast<float *>(out_tensors_[1]->Data());
size_t dy_size = inputs_.at(0)->ElementsNum();
size_t dx1_size = outputs_.at(0)->ElementsNum();
size_t dx2_size = outputs_[1]->ElementsNum();
size_t dy_size = in_tensors_.at(0)->ElementsNum();
size_t dx1_size = out_tensors_.at(0)->ElementsNum();
size_t dx2_size = out_tensors_[1]->ElementsNum();
(this->*arithmetic_grad_)(dy, dy_size, dx1, dx1_size, dx2, dx2_size);
return RET_OK;
}
......
......@@ -40,7 +40,7 @@ class ArithmeticGradCPUKernel : public LiteKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), tile_data0(NULL), tile_data1(NULL), tile_data2(NULL) {
switch (type()) {
switch (Type()) {
case PrimitiveType_MulGrad:
arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul; // this will be adjusted in InferShape
break;
......
......@@ -27,33 +27,9 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_BiasGrad;
namespace mindspore::kernel {
int BiasGradCPUKernel::InferShape() {
if (1 != this->inputs_.size()) {
MS_LOG(ERROR) << "BiasGrad should have one input";
return RET_ERROR;
}
if (1 != this->outputs_.size()) {
MS_LOG(ERROR) << "BiasGrad should have one output";
return RET_ERROR;
}
auto *in0 = inputs_.front();
auto *out = outputs_.front();
MS_ASSERT(in0 != nullptr);
MS_ASSERT(out != nullptr);
auto inshape = in0->shape();
int ndim = inshape.size();
for (int i = 0; i < ndim - 1; i++) {
inshape[i] = 1;
}
out->set_shape(inshape);
out->set_data_type(in0->data_type());
return RET_OK;
}
int BiasGradCPUKernel::Init() {
MS_ASSERT(InferShape() == RET_OK);
auto dims = inputs_[0]->shape();
auto dims = in_tensors_[0]->shape();
bias_param->ndim_ = dims.size();
for (unsigned int i = 0; i < bias_param->ndim_; i++) {
bias_param->in_shape0_[i] = dims[i];
......@@ -75,8 +51,8 @@ int BiasGradCPUKernel::Run() {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
auto in = reinterpret_cast<float *>(inputs_.at(0)->Data());
auto out = reinterpret_cast<float *>(outputs_.at(0)->Data());
auto in = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
auto out = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
size_t nhw_size = 1;
size_t channels = bias_param->in_shape0_[bias_param->ndim_ - 1]; // C in NHWC
......
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_GRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_GRAD_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BIAS_GRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BIAS_GRAD_H_
#include <vector>
#include "src/lite_kernel.h"
......@@ -35,7 +35,6 @@ class BiasGradCPUKernel : public LiteKernel {
~BiasGradCPUKernel() override = default;
int Init() override;
int InferShape();
int ReSize() override;
int Run() override;
......@@ -44,4 +43,4 @@ class BiasGradCPUKernel : public LiteKernel {
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_GRAD_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BIAS_GRAD_H_
......@@ -14,11 +14,11 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32_grad/bn_grad.h"
#include <algorithm>
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/fp32_grad/bn_grad.h"
#include "nnacl/fp32_grad/batch_norm.h"
#include "include/errorcode.h"
......@@ -27,79 +27,103 @@ using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
// using mindspore::lite::REG_OP;
using mindspore::schema::PrimitiveType_BNGradInput;
using mindspore::schema::PrimitiveType_BNGrad;
/*
{dy}
{x }
{scale }
{save_mean }
{save_inv_variance }
*/
namespace mindspore::kernel {
int BNGradInputCPUKernel::Init() {
auto bn_param = reinterpret_cast<bnParameter *>(opParameter);
workspace_size = 5 * bn_param->channels;
workspace = new (std::nothrow) float[workspace_size];
if (workspace == nullptr) {
MS_LOG(ERROR) << "new workspace fail!";
return RET_ERROR;
}
if (2 != this->inputs_.size()) {
MS_LOG(ERROR) << "Conv2d Grad should has 2 inputs";
return RET_ERROR;
#if 0
OpParameter *PopulateBNGradParameter(const lite::Primitive *primitive) {
BNGradParameter *param = new (std::nothrow) BNGradParameter();
if (param == nullptr) {
MS_LOG(ERROR) << "new Param for conv grad filter failed.";
return nullptr;
}
if (1 != this->outputs_.size()) {
MS_LOG(ERROR) << "Conv2d Grad should has one output";
param->op_parameter_.type_ = primitive->Type();
auto bngrad_primitive = primitive->Value()->value_as_BNGrad();
param->epsilon_ = bngrad_primitive->eps();
param->momentum_ = bngrad_primitive->momentum();
return reinterpret_cast<OpParameter *>(param);
}
#endif
int BNGradCPUKernel::Init() {
auto *input_x = in_tensors_.at(1);
int channels = input_x->shape().at(kNHWC_C);
workspace_size = 5 * channels;
workspace = new (std::nothrow) float[workspace_size];
if (workspace == nullptr) {
MS_LOG(ERROR) << "new workspace fail!";
return RET_ERROR;
}
auto *input_tensor = inputs_.at(0);
auto *out_tensor = outputs_.at(0);
auto in_shape = input_tensor->shape();
out_tensor->set_shape(in_shape);
out_tensor->set_data_type(input_tensor->data_type());
return RET_OK;
}
int BNGradInputCPUKernel::ReSize() { return RET_OK; }
int BNGradCPUKernel::ReSize() { return RET_OK; }
int BNGradInputCPUKernel::Run() {
auto *input_x = inputs_.at(0);
auto *input_yt = inputs_.at(1);
auto *input_scale = inputs_.at(2);
auto *output_grad = outputs_.at(0);
auto bn_param = reinterpret_cast<bnParameter *>(opParameter);
int batch = bn_param->batch;
int channels = bn_param->channels;
int spatial = bn_param->spatial;
float eps = bn_param->eps;
int BNGradCPUKernel::Run() {
// std::cout << "run succ" << std::endl;
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_);
auto *input_yt = in_tensors_.at(0);
auto *input_x = in_tensors_.at(1);
auto *input_scale = in_tensors_.at(2);
auto *output_dx = out_tensors_.at(0);
auto *output_scale = out_tensors_.at(1);
auto *output_bias = out_tensors_.at(2);
// Tensor *bias = input[5];
int batch = input_x->Batch();
int channels = input_x->Channel();
int spatial = input_x->Height() * input_x->Width();
float eps = bn_param->epsilon_;
std::fill(workspace, workspace + workspace_size, 0.f);
float *mean = workspace;
float *variance = mean + channels;
float *mean_delta = variance + channels;
float *invar = mean + channels;
float *mean_delta = invar + channels;
float *variance_delta = mean_delta + channels;
float *mean_add_delta = variance_delta + channels;
float *x = reinterpret_cast<float *>(input_x->Data());
float *yt = reinterpret_cast<float *>(input_yt->Data());
float *scale = reinterpret_cast<float *>(input_scale->Data());
float *out = reinterpret_cast<float *>(output_grad->Data());
float *dx = reinterpret_cast<float *>(output_dx->Data());
float *dscale = reinterpret_cast<float *>(output_scale->Data());
float *dbias = reinterpret_cast<float *>(output_bias->Data());
std::copy(yt, yt + batch * channels * spatial, out);
meanVar(x, batch, spatial, channels, mean, variance);
scaleBias(scale, batch, channels, spatial, out);
meanDelta(out, spatial, channels, eps, variance, mean_delta);
varianceDelta(x, out, mean, variance, batch, channels, spatial, eps, variance_delta);
std::copy(yt, yt + batch * channels * spatial, dx);
meanVar(x, batch, spatial, channels, eps, mean, invar);
scaleBias(scale, batch, channels, spatial, dx);
meanDelta(dx, spatial, channels, invar, mean_delta);
varianceDelta(x, dx, mean, invar, batch, channels, spatial, variance_delta);
meanAdd(x, mean, variance_delta, batch, channels, spatial, mean_add_delta, mean_delta);
NormalizeDelta(x, mean, variance, mean_delta, variance_delta, batch, channels, eps, spatial, out);
NormalizeDelta(x, mean, invar, mean_delta, variance_delta, batch, channels, spatial, dx);
// dbias
sumSpatialBatch(yt, batch * spatial, channels, dbias);
// dscale
backwardScale(x, mean, invar, yt, batch, channels, spatial, dscale);
return RET_OK;
}
kernel::LiteKernel *CpuBNGradInputFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
kernel::LiteKernel *CpuBNGradFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_BNGradInput);
auto *kernel = new (std::nothrow) BNGradInputCPUKernel(opParameter, inputs, outputs, ctx, primitive);
MS_ASSERT(desc.type == schema::PrimitiveType_BNGrad);
auto *kernel = new (std::nothrow) BNGradCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new BNGradInputCPUKernel fail!";
MS_LOG(ERROR) << "new BNGradCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
......@@ -112,5 +136,5 @@ kernel::LiteKernel *CpuBNGradInputFp32KernelCreator(const std::vector<lite::tens
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BNGradInput, CpuBNGradInputFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BNGrad, CpuBNGradFp32KernelCreator)
} // namespace mindspore::kernel
......@@ -14,21 +14,25 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BN_GRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BN_GRAD_H_
#include <vector>
#include "src/lite_kernel.h"
#include "ir/anf.h"
namespace mindspore::kernel {
class BNGradInputCPUKernel : public LiteKernel {
class BNGradCPUKernel : public LiteKernel {
public:
explicit BNGradInputCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
explicit BNGradCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~BNGradInputCPUKernel() override { delete workspace; }
~BNGradCPUKernel() override { delete workspace; }
int Init() override;
int ReSize() override;
......@@ -38,5 +42,8 @@ class BNGradInputCPUKernel : public LiteKernel {
float *workspace;
int workspace_size;
};
// OpParameter *PopulateBNGradParameter(const lite::Primitive *primitive);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BN_GRAD_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32_grad/convolution.h"
#include "nnacl/fp32_grad/pack_ext.h"
#include "nnacl/fp32_grad/gemm.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
int ConvolutionTrainCPUKernel::Init() {
auto conv_param_ = reinterpret_cast<ConvParameter *>(op_parameter_);
auto *input_x = in_tensors_.at(kInputIndex);
auto *input_weight = in_tensors_.at(kWeightIndex);
auto *out_y = out_tensors_.at(kOutputIndex);
conv_param_->output_batch_ = out_y->shape().at(kNHWC_N);
conv_param_->input_batch_ = input_x->shape().at(kNHWC_N);
conv_param_->input_h_ = input_x->shape().at(kNHWC_H);
conv_param_->input_w_ = input_x->shape().at(kNHWC_W);
conv_param_->output_h_ = out_y->shape().at(kNHWC_H);
conv_param_->output_w_ = out_y->shape().at(kNHWC_W);
conv_param_->input_channel_ = input_x->shape().at(kNHWC_C);
conv_param_->output_channel_ = input_weight->shape().at(kNHWC_N);
conv_param_->kernel_h_ = input_weight->shape().at(kNHWC_H);
conv_param_->kernel_w_ = input_weight->shape().at(kNHWC_W);
int ws_size = conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->kernel_h_ * conv_param_->kernel_w_ *
conv_param_->input_channel_ / conv_param_->group_;
workspace = new float[ws_size];
return RET_OK;
}
int ConvolutionTrainCPUKernel::ReSize() { return RET_OK; }
int ConvolutionTrainCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
auto conv_param_ = reinterpret_cast<ConvParameter *>(op_parameter_);
auto *input_x = in_tensors_.at(kInputIndex);
auto *input_w = in_tensors_.at(kWeightIndex);
auto *out_y = out_tensors_.at(kOutputIndex);
auto x_addr = reinterpret_cast<float *>(input_x->Data());
auto y_addr = reinterpret_cast<float *>(out_y->Data());
auto w_addr = reinterpret_cast<float *>(input_w->Data());
int i, j;
int nweights = input_w->ElementsNum();
int in_ch = conv_param_->input_channel_;
int in_h = conv_param_->input_h_;
int in_w = conv_param_->input_w_;
int k_h = conv_param_->kernel_h_;
int k_w = conv_param_->kernel_w_;
int batch = conv_param_->output_batch_;
int out_ch = conv_param_->output_channel_; // out_y->shape()[3];
int groups = conv_param_->group_;
int out_h = conv_param_->output_h_;
int out_w = conv_param_->output_w_;
int m = out_h * out_w;
int n = out_ch / groups;
int k = k_h * k_w * in_ch / groups;
memset(y_addr, 0, out_y->Size());
for (i = 0; i < batch; ++i) {
for (j = 0; j < groups; ++j) {
float *mat_a = workspace;
float *mat_b = w_addr + j * nweights / groups;
float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups);
float *im = x_addr + (i * groups) * (in_ch / groups) * in_h * in_w + j * (in_ch / groups);
im2col_hwc(im, mat_a, conv_param_);
gemm(0, 1, m, n, k, 1, mat_a, k, mat_b, k, 1, mat_c, out_ch);
}
}
// std::cout << "run succ" << std::endl;
return RET_OK;
}
kernel::LiteKernel *CpuConvTrainFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc, const lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
auto *kernel = new (std::nothrow) ConvolutionTrainCPUKernel(opParameter, inputs, outputs, ctx, primitive);
MS_ASSERT(kernel != nullptr);
auto ret = kernel->Init();
if (RET_OK != ret) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
} // namespace mindspore::kernel
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_H_
#include <vector>
#include "src/lite_kernel.h"
#include "ir/anf.h"
namespace mindspore::kernel {
class ConvolutionTrainCPUKernel : public LiteKernel {
public:
explicit ConvolutionTrainCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~ConvolutionTrainCPUKernel() override { delete [] workspace; }
int Init() override;
int ReSize() override;
int Run() override;
private:
float *workspace;
};
kernel::LiteKernel *CpuConvTrainFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc, const lite::PrimitiveC *primitive);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_H_
......@@ -33,30 +33,24 @@ int ConvolutionGradFilterCPUKernel::Init() {
// x is in input 1
// dw is output 0
if (2 != this->inputs_.size()) {
MS_LOG(ERROR) << "Conv2d Grad should has 2 inputs";
return RET_ERROR;
}
if (1 != this->outputs_.size()) {
MS_LOG(ERROR) << "Conv2d Grad should has one output";
return RET_ERROR;
}
auto *input_tensor = inputs_.at(1);
MS_ASSERT(input_tensor != nullptr);
auto *dy = inputs_.at(0);
MS_ASSERT(dy != nullptr);
auto *weight_tensor = outputs_.at(0);
auto *x_tensor = in_tensors_.at(1);
MS_ASSERT(x_tensor != nullptr);
auto *dy_tensor = in_tensors_.at(0);
MS_ASSERT(dy_tensor != nullptr);
auto *weight_tensor = out_tensors_.at(0);
MS_ASSERT(weight_tensor != nullptr);
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
conv_param->output_batch_ = this->inputs_.at(0)->shape().at(kNHWC_N);
conv_param->input_batch_ = this->inputs_.at(1)->shape().at(kNHWC_N);
conv_param->input_h_ = this->inputs_.at(1)->shape().at(kNHWC_H);
conv_param->input_w_ = this->inputs_.at(1)->shape().at(kNHWC_W);
// assume OutCh|kh|kw|In
conv_param->input_channel_ = this->inputs_.at(1)->shape().at(kNHWC_C);
conv_param->output_channel_ = this->outputs_.at(0)->shape().at(kNHWC_N);
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_);
conv_param->output_batch_ = dy_tensor->shape().at(kNHWC_N);
conv_param->input_batch_ = x_tensor->shape().at(kNHWC_N);
conv_param->input_h_ = x_tensor->shape().at(kNHWC_H);
conv_param->input_w_ = x_tensor->shape().at(kNHWC_W);
// assume OutCh|kh|kw|InCh
conv_param->input_channel_ = x_tensor->shape().at(kNHWC_C);
conv_param->output_channel_ = dy_tensor->shape().at(kNHWC_C);
// TBD
conv_param->output_h_ = dy_tensor->shape()[kNHWC_H];
conv_param->output_w_ = dy_tensor->shape()[kNHWC_W];
int ws_size = conv_param->output_h_ * conv_param->output_w_ * conv_param->kernel_h_ * conv_param->kernel_w_ *
conv_param->input_channel_ / conv_param->group_;
......@@ -67,34 +61,21 @@ int ConvolutionGradFilterCPUKernel::Init() {
return RET_ERROR;
}
int output_w = 0;
int output_h = 0;
output_h = dy->shape()[kNHWC_H];
output_w = dy->shape()[kNHWC_W];
std::vector<int> out_shape(4);
out_shape.at(0) = conv_param->output_channel_;
out_shape.at(1) = conv_param->kernel_h_;
out_shape.at(2) = conv_param->kernel_w_;
out_shape.at(3) = conv_param->input_channel_ / conv_param->group_;
// weight is output
weight_tensor->set_shape(out_shape);
weight_tensor->set_data_type(input_tensor->data_type());
conv_param->output_h_ = output_h;
conv_param->output_w_ = output_w;
return RET_OK;
}
int ConvolutionGradFilterCPUKernel::ReSize() { return 0; }
int ConvolutionGradFilterCPUKernel::ReSize() { return RET_OK; }
int ConvolutionGradFilterCPUKernel::Run() {
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
auto *input_dy = inputs_.at(0);
auto *input_x = inputs_.at(1);
auto *out_dw = outputs_.at(0);
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_);
auto *input_dy = in_tensors_.at(0);
auto *input_x = in_tensors_.at(1);
auto *out_dw = out_tensors_.at(0);
auto x_addr = reinterpret_cast<float *>(input_x->Data());
auto dy_addr = reinterpret_cast<float *>(input_dy->Data());
......@@ -135,7 +116,48 @@ int ConvolutionGradFilterCPUKernel::Run() {
// std::cout << "run succ" << std::endl;
return RET_OK;
}
#if 0
OpParameter *PopulateConvolutionGradFilterParameter(const lite::Primitive *primitive) {
ConvParameter *param = new (std::nothrow) ConvParameter();
if (param == nullptr) {
MS_LOG(ERROR) << "new Param for conv grad filter failed.";
return nullptr;
}
param->op_parameter_.type_ = primitive->Type();
auto convg_primitive = primitive->Value()->value_as_Conv2DGradFilter();
param->kernel_h_ = convg_primitive->kernelH();
param->kernel_w_ = convg_primitive->kernelW();
param->stride_h_ = convg_primitive->strideH();
param->stride_w_ = convg_primitive->strideW();
param->dilation_h_ = convg_primitive->dilateH();
param->dilation_w_ = convg_primitive->dilateW();
param->pad_h_ = convg_primitive->padUp();
param->pad_w_ = convg_primitive->padLeft();
param->pad_u_ = convg_primitive->padUp();
param->pad_d_ = convg_primitive->padDown();
param->pad_l_ = convg_primitive->padLeft();
param->pad_r_ = convg_primitive->padRight();
param->group_ = convg_primitive->group();
auto act_type = convg_primitive->activationType();
switch (act_type) {
case schema::ActivationType_RELU:
param->is_relu_ = true;
param->is_relu6_ = false;
break;
case schema::ActivationType_RELU6:
param->is_relu_ = false;
param->is_relu6_ = true;
break;
default:
param->is_relu_ = false;
param->is_relu6_ = false;
break;
}
return reinterpret_cast<OpParameter *>(param);
}
#endif
kernel::LiteKernel *CpuConvGradFilterFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
......
/**
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
......@@ -28,15 +28,17 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~ConvolutionGradFilterCPUKernel() override { delete workspace; }
~ConvolutionGradFilterCPUKernel() override { delete [] workspace; }
int Init() override;
int ReSize() override;
int Run() override;
private:
float *workspace;
float *workspace = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
......@@ -29,23 +29,14 @@ using mindspore::schema::PrimitiveType_Conv2DGradInput;
namespace mindspore::kernel {
int ConvolutionGradInputCPUKernel::Init() {
if (2 != this->inputs_.size()) {
MS_LOG(ERROR) << "Conv2d Grad should has 2 inputs";
return RET_ERROR;
}
if (1 != this->outputs_.size()) {
MS_LOG(ERROR) << "Conv2d Grad should has one output";
return RET_ERROR;
}
auto *dy_tensor = inputs_.at(kInputIndex);
auto *dy_tensor = in_tensors_.at(kInputIndex);
MS_ASSERT(dy_tensor != nullptr);
auto *weight_tensor = inputs_.at(kWeightIndex);
auto *weight_tensor = in_tensors_.at(kWeightIndex);
MS_ASSERT(weight_tensor != nullptr);
auto *dx_tensor = outputs_.at(kOutputIndex);
auto *dx_tensor = out_tensors_.at(kOutputIndex);
MS_ASSERT(dx_tensor != nullptr);
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_);
conv_param->output_batch_ = dx_tensor->shape()[(kNHWC_N)];
conv_param->input_batch_ = dy_tensor->shape()[(kNHWC_N)];
......@@ -74,10 +65,16 @@ int ConvolutionGradInputCPUKernel::Init() {
int ConvolutionGradInputCPUKernel::ReSize() { return 0; }
int ConvolutionGradInputCPUKernel::Run() {
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
auto *input_dy = inputs_.at(0);
auto *input_w = inputs_.at(1);
auto *out_dx = outputs_.at(0);
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_);
auto *input_dy = in_tensors_.at(0);
auto *input_w = in_tensors_.at(1);
auto *out_dx = out_tensors_.at(0);
auto dy_addr = reinterpret_cast<float *>(input_dy->Data());
auto w_addr = reinterpret_cast<float *>(input_w->Data());
......@@ -116,6 +113,49 @@ int ConvolutionGradInputCPUKernel::Run() {
return 0;
}
#if 0
OpParameter *PopulateConvolutionGradInputParameter(const lite::Primitive *primitive) {
ConvParameter *param = new (std::nothrow) ConvParameter();
if (param == nullptr) {
MS_LOG(ERROR) << "new Param for conv grad input failed.";
return nullptr;
}
param->op_parameter_.type_ = primitive->Type();
auto convg_primitive = primitive->Value()->value_as_Conv2DGradInput();
param->kernel_h_ = convg_primitive->kernelH();
param->kernel_w_ = convg_primitive->kernelW();
param->stride_h_ = convg_primitive->strideH();
param->stride_w_ = convg_primitive->strideW();
param->dilation_h_ = convg_primitive->dilateH();
param->dilation_w_ = convg_primitive->dilateW();
param->pad_h_ = convg_primitive->padUp();
param->pad_w_ = convg_primitive->padLeft();
param->pad_u_ = convg_primitive->padUp();
param->pad_d_ = convg_primitive->padDown();
param->pad_l_ = convg_primitive->padLeft();
param->pad_r_ = convg_primitive->padRight();
param->group_ = convg_primitive->group();
auto act_type = convg_primitive->activationType();
switch (act_type) {
case schema::ActivationType_RELU:
param->is_relu_ = true;
param->is_relu6_ = false;
break;
case schema::ActivationType_RELU6:
param->is_relu_ = false;
param->is_relu6_ = true;
break;
default:
param->is_relu_ = false;
param->is_relu6_ = false;
break;
}
return reinterpret_cast<OpParameter *>(param);
}
#endif
kernel::LiteKernel *CpuConvGradInputFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
......
......@@ -28,7 +28,7 @@ class ConvolutionGradInputCPUKernel : public LiteKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~ConvolutionGradInputCPUKernel() override { delete workspace; }
~ConvolutionGradInputCPUKernel() override { delete [] workspace; }
int Init() override;
int ReSize() override;
......@@ -37,6 +37,9 @@ class ConvolutionGradInputCPUKernel : public LiteKernel {
private:
float *workspace;
};
// OpParameter *PopulateConvolutionGradInputParameter(const lite::Primitive *primitive);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/runtime/kernel/arm/fp32_grad/depend.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Depend;
namespace mindspore::kernel {
int DependCPUKernel::Init() {
return RET_OK;
}
int DependCPUKernel::ReSize() { return 0; }
int DependCPUKernel::Run() {
#if 0
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
auto in = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
auto out = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
memcpy(out, in, in_tensors_.at(0)->Size());
#endif
return RET_OK;
}
kernel::LiteKernel *CpuDependFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc, const lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Depend);
auto *kernel =
new (std::nothrow) DependCPUKernel(opParameter, inputs, outputs, ctx, primitive);
MS_ASSERT(kernel != nullptr);
auto ret = kernel->Init();
if (RET_OK != ret) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Depend, CpuDependFp32KernelCreator)
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_DEPEND_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_DEPEND_H_
#include <vector>
#include "src/lite_kernel.h"
#include "ir/anf.h"
#include "nnacl/fp32/arithmetic.h"
namespace mindspore::kernel {
class DependCPUKernel : public LiteKernel {
public:
explicit DependCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param = parameter;
}
~DependCPUKernel() override = default;
int Init() override;
int ReSize() override;
int Run() override;
private:
OpParameter *param;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_DEPEND_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_MAKE_TUPLE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_MAKE_TUPLE_H_
#include <vector>
#include "src/lite_kernel.h"
#include "ir/anf.h"
#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h"
namespace mindspore::kernel {
class MakeTupleCPUKernel : public LiteKernel {
public:
explicit MakeTupleCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param = parameter;
}
~MakeTupleCPUKernel() override = default;
int Init() override;
int ReSize() override;
int Run() override;
private:
OpParameter *param;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_MAKE_TUPLE_H_
......@@ -20,6 +20,7 @@
#include "nnacl/fp32/pooling.h"
#include "nnacl/fp32_grad/pooling_grad.h"
#include "include/errorcode.h"
// #include "src/train/ops/train_ops.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
......@@ -29,9 +30,15 @@ using mindspore::schema::PrimitiveType_PoolingGrad;
namespace mindspore::kernel {
int PoolingGradCPUKernel::Init() {
PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(opParameter);
PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(op_parameter_);
auto in_shape = inputs_.at(0)->shape();
auto in_shape = in_tensors_.at(0)->shape();
auto out_shape = in_tensors_.at(1)->shape();
if (pool_param->pool_mode_ == PoolMode_AvgPool) {
in_shape = in_tensors_.at(1)->shape();
out_shape = in_tensors_.at(0)->shape();
}
int input_h = in_shape.at(1);
int input_w = in_shape.at(2);
......@@ -40,25 +47,39 @@ int PoolingGradCPUKernel::Init() {
pool_param->window_h_ = input_h;
}
pool_param->input_h_ = in_shape[kNHWC_H];
pool_param->input_w_ = in_shape[kNHWC_W];
pool_param->input_batch_ = in_shape[kNHWC_N];
pool_param->input_channel_ = in_shape[kNHWC_C];
// Emir -- here I assume we get the outputshape in the output tensor
auto *out_tensor = outputs_.front();
auto out_shape = out_tensor->shape();
// auto *out_tensor = out_tensors_.front();
// auto out_shape = in_tensors_.at(1)->shape();
pool_param->output_h_ = out_shape[kNHWC_H];
pool_param->output_w_ = out_shape[kNHWC_W];
pool_param->output_batch_ = out_shape[kNHWC_N];
pool_param->output_channel_ = out_shape[kNHWC_C];
out_tensor->set_shape(out_shape);
out_tensor->set_data_type(inputs_.at(0)->data_type());
return RET_OK;
}
int PoolingGradCPUKernel::ReSize() { return RET_OK; }
int PoolingGradCPUKernel::Run() {
PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(opParameter);
auto input_ptr = reinterpret_cast<float *>(inputs_.at(0)->Data());
auto output_ptr = reinterpret_cast<float *>(outputs_.at(0)->Data());
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(op_parameter_);
auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
if (pool_param->pool_mode_ == PoolMode_MaxPool) {
auto ind = reinterpret_cast<int *>(inputs_.at(1)->Data());
MaxPoolingGrad(input_ptr, ind, output_ptr, pool_param);
auto dx_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->Data());
auto dy_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->Data());
MaxPoolingGrad(input_ptr, dx_ptr, dy_ptr, output_ptr, pool_param);
} else {
AvgPoolingGrad(input_ptr, output_ptr, pool_param);
}
......
......@@ -43,6 +43,7 @@ class PoolingGradCPUKernel : public LiteKernel {
private:
uint8_t data_shape_{0};
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_
......@@ -31,10 +31,10 @@ int PowerGradCPUKernel::Init() { return RET_OK; }
int PowerGradCPUKernel::ReSize() { return RET_OK; }
int PowerGradCPUKernel::Run() {
auto dy_addr = reinterpret_cast<float *>(inputs_.at(0)->Data());
auto x_addr = reinterpret_cast<float *>(inputs_.at(1)->Data());
auto dx_addr = reinterpret_cast<float *>(outputs_.at(0)->Data());
auto size = inputs_.at(0)->ElementsNum();
auto dy_addr = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
auto x_addr = reinterpret_cast<float *>(in_tensors_.at(1)->Data());
auto dx_addr = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
auto size = in_tensors_.at(0)->ElementsNum();
float exp = power_ - 1;
Power(x_addr, &exp, dx_addr, size, scale_, shift_, true);
......@@ -47,6 +47,7 @@ int PowerGradCPUKernel::Run() {
return RET_OK;
}
kernel::LiteKernel *CpuPowerGradFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
......
......@@ -45,6 +45,7 @@ class PowerGradCPUKernel : public LiteKernel {
float scale_;
float shift_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_
......@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <math.h>
#include "src/kernel_registry.h"
#include "nnacl/softmax_parameter.h"
#include "nnacl/fp32/softmax.h"
......@@ -46,9 +47,10 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int
output[0] = total_loss / param->batch_size_;
}
void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses,
void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, float *grads,
float *output) const {
size_t row_start = 0;
float total_loss = 0;
for (int i = 0; i < param->batch_size_; ++i) {
if (labels[i] < 0) {
MS_LOG(EXCEPTION) << "label value must >= 0";
......@@ -56,78 +58,88 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *la
size_t label = labels[i];
if (label > param->number_of_classes_) {
MS_LOG(EXCEPTION) << "error label input!";
}
for (size_t j = 0; j < param->number_of_classes_; ++j) {
size_t index = row_start + j;
if (j == label) {
output[index] = (losses[index] - 1) / param->batch_size_;
} else {
output[index] = losses[index] / param->batch_size_;
} else {
total_loss -= logf(losses[i * param->number_of_classes_ + label]);
for (size_t j = 0; j < param->number_of_classes_; ++j) {
size_t index = row_start + j;
if (j == label) {
grads[index] = (losses[index] - 1) / param->batch_size_;
} else {
grads[index] = losses[index] / param->batch_size_;
}
}
}
row_start += param->number_of_classes_;
}
output[0] = total_loss / param->batch_size_;
}
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
auto ins = reinterpret_cast<float *>(inputs_.at(0)->Data());
auto labels = reinterpret_cast<int *>(inputs_.at(1)->Data());
auto out = reinterpret_cast<float *>(outputs_.at(1)->Data());
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed.";
return ret;
}
auto ins = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
auto labels = reinterpret_cast<int *>(in_tensors_.at(1)->Data());
float *out = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
float *grads = NULL;
if (is_train()) { // outputs_.size() > 1)
grads = reinterpret_cast<float *>(outputs_.at(0)->Data());
if (is_train() && out_tensors_.size() > 1) {
grads = reinterpret_cast<float *>(out_tensors_.at(1)->Data());
}
size_t data_size = inputs_.at(0)->ElementsNum();
size_t data_size = in_tensors_.at(0)->ElementsNum();
float *losses = new (std::nothrow) float[data_size];
if (losses == nullptr) {
MS_LOG(ERROR) << "losses is null";
return nullptr;
return RET_ERROR;
}
std::fill(losses, losses + data_size, 0);
MS_ASSERT(out != nullptr);
MS_ASSERT(labels != nullptr);
MS_ASSERT(ins != nullptr);
SoftmaxParameter sm_params;
sm_params.n_dim_ = param->n_dim_;
sm_params.element_size_ = data_size;
sm_params.axis_ = 0;
for (int i = 0; i < 4; i++) // softmax has only 4 params in shape
sm_params.input_shape_[i] = param->input_shape_[i];
float sum_data[sm_params.input_shape_[sm_params.axis_]] = {0};
std::fill(sum_data, sum_data + sm_params.input_shape_[sm_params.axis_], 0);
Softmax(ins, losses, sum_data, &sm_params);
std::fill(losses_, losses_ + data_size, 0);
std::fill(sum_data_, sum_data_ + sm_params_.input_shape_[0], 0);
Softmax(ins, losses_, sum_data_, &sm_params_);
if (is_train()) {
GradPostExecute(labels, losses, grads);
} else {
ForwardPostExecute(labels, losses, out);
GradPostExecute(labels, losses_, grads, out);
} else if (out != nullptr) {
ForwardPostExecute(labels, losses_, out);
}
return RET_OK;
}
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
SetNeedReInit();
return RET_OK;
}
auto dims = inputs_[0]->shape();
// if (context_ && context_->infer_shape_interrupt_ && !context_->running_) {
// set_need_reinit();
// return RET_OK;
// }
auto dims = in_tensors_[0]->shape();
param->n_dim_ = 2;
param->number_of_classes_ = dims[1];
param->batch_size_ = dims[0];
for (unsigned int i = 0; i < dims.size(); i++) param->input_shape_[i] = dims[i];
if (2 != this->inputs_.size()) {
if (2 != this->in_tensors_.size()) {
MS_LOG(ERROR) << "softmax entropy loss should have two inputs";
return RET_ERROR;
}
auto *in0 = inputs_.front();
auto *in0 = in_tensors_.front();
if (in0 == nullptr) {
MS_LOG(ERROR) << "softmax etropy loss in0 have no data";
return RET_ERROR;
}
size_t data_size = in_tensors_.at(0)->ElementsNum();
losses_ = new (std::nothrow) float[data_size];
sum_data_ = new (std::nothrow) float[dims[0]];
MS_ASSERT(losses_ != nullptr);
MS_ASSERT(sum_data_ != nullptr);
sm_params_.n_dim_ = 2;
sm_params_.element_size_ = data_size;
sm_params_.axis_ = 1;
for (int i = 0; i < dims.size(); i++) sm_params_.input_shape_[i] = dims[i];
return RET_OK;
}
......
......@@ -14,31 +14,32 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/train/loss_kernel.h"
#include "ir/anf.h"
#include "nnacl/fp32_grad/softmax_grad.h"
#include "nnacl/fp32/arithmetic.h"
#include "nnacl/softmax_parameter.h"
namespace mindspore::kernel {
class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LiteKernel {
class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel {
public:
explicit SparseSoftmaxCrossEntropyWithLogitsCPUKernel(OpParameter *parameter,
const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
: LossKernel(parameter, inputs, outputs, ctx, primitive) {
param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(parameter);
}
~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override = default;
~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override { delete[] losses_; delete[] sum_data_; }
void ForwardPostExecute(const int *labels, const float *losses, float *output) const;
void GradPostExecute(const int *labels, const float *losses, float *output) const;
void GradPostExecute(const int *labels, const float *losses, float* grads, float *output) const;
int Init() override;
int ReSize() override;
......@@ -46,7 +47,11 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LiteKernel {
private:
SoftmaxCrossEntropyParameter *param;
SoftmaxParameter sm_params_;
float *losses_ = nullptr;
float *sum_data_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/runtime/kernel/arm/fp32_grad/tuple_getitem.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_TupleGetItem;
namespace mindspore::kernel {
int TupleGetItemCPUKernel::Init() {
return RET_OK;
}
int TupleGetItemCPUKernel::ReSize() { return 0; }
int TupleGetItemCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
auto in = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
auto out = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
memcpy(out, in, in_tensors_.at(0)->Size());
return RET_OK;
}
kernel::LiteKernel *CpuTupleGetItemFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc, const lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_TupleGetItem);
auto *kernel =
new (std::nothrow) TupleGetItemCPUKernel(opParameter, inputs, outputs, ctx, primitive);
MS_ASSERT(kernel != nullptr);
auto ret = kernel->Init();
if (RET_OK != ret) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TupleGetItem, CpuTupleGetItemFp32KernelCreator)
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_TUPLE_GETITEM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_TUPLE_GETITEM_H_
#include <vector>
#include "src/lite_kernel.h"
#include "ir/anf.h"
#include "nnacl/fp32/arithmetic.h"
namespace mindspore::kernel {
class TupleGetItemCPUKernel : public LiteKernel {
public:
explicit TupleGetItemCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param = parameter;
}
~TupleGetItemCPUKernel() override = default;
int Init() override;
int ReSize() override;
int Run() override;
private:
OpParameter *param;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_TUPLE_GETITEM_H_
......@@ -94,8 +94,10 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *
inputs.emplace_back(tensors->at(size_t(inIndexes->GetAs<uint32_t>(j))));
}
auto outIndexes = cNode->outputIndex();
for (size_t j = 0; j < outIndexes->size(); j++) {
outputs.emplace_back(tensors->at(size_t(outIndexes->GetAs<uint32_t>(j))));
if (outIndexes != nullptr) {
for (size_t j = 0; j < outIndexes->size(); j++) {
outputs.emplace_back(tensors->at(size_t(outIndexes->GetAs<uint32_t>(j))));
}
}
auto *primitive = model->GetOp(cNode->name()->str());
if (primitive == nullptr) {
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_TRAIN_LOSS_KERNEL_H_
#define MINDSPORE_LITE_SRC_TRAIN_LOSS_KERNEL_H_
#include <vector>
#include "src/lite_kernel.h"
namespace mindspore::kernel {
class LossKernel : public LiteKernel {
public:
LossKernel() = default;
explicit LossKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
const lite::Context *ctx,
const lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~LossKernel() = default;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_TRAIN_LOSS_KERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/populate_parameter.h"
#include "src/train/train_populate_parameter.h"
#include "src/ops/pooling_grad.h"
#include "nnacl/pooling_parameter.h"
#include "src/ops/softmax_cross_entropy.h"
#include "nnacl/fp32_grad/softmax_grad.h"
#include "src/ops/activation_grad.h"
#include "nnacl/fp32/activation.h"
#include "src/ops/conv2d_grad_filter.h"
#include "src/ops/conv2d_grad_input.h"
#include "nnacl/conv_parameter.h"
#include "src/ops/power_grad.h"
#include "nnacl/power_parameter.h"
namespace mindspore::kernel {
OpParameter *DefaultPopulateParameter(const mindspore::lite::PrimitiveC *primitive) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
return nullptr;
}
OpParameter *param = new (std::nothrow) OpParameter();
if (param == nullptr) {
MS_LOG(ERROR) << "new Param for primitive failed.";
return nullptr;
}
param->type_ = primitive->Type();
return param;
}
OpParameter *PopulateSoftmaxCrossEntropyParameter(const mindspore::lite::PrimitiveC *primitive) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
return nullptr;
}
SoftmaxCrossEntropyParameter *sce_param = new (std::nothrow) SoftmaxCrossEntropyParameter();
if (sce_param == nullptr) {
MS_LOG(ERROR) << "new SoftmaxCrossEntropyParameter failed.";
return nullptr;
}
sce_param->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(sce_param);
}
OpParameter *PopulatePoolingGradParameter(const mindspore::lite::PrimitiveC *primitive) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
return nullptr;
}
PoolingParameter *pooling_param = new (std::nothrow) PoolingParameter();
if (pooling_param == nullptr) {
MS_LOG(ERROR) << "new PoolingParameter failed.";
return nullptr;
}
pooling_param->op_parameter_.type_ = primitive->Type();
auto pooling_primitive =
reinterpret_cast<mindspore::lite::PoolingGrad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
pooling_param->global_ = pooling_primitive->GetGlobal();
pooling_param->window_w_ = pooling_primitive->GetWindowW();
pooling_param->window_h_ = pooling_primitive->GetWindowH();
pooling_param->pad_u_ = pooling_primitive->GetPadUp();
pooling_param->pad_d_ = pooling_primitive->GetPadDown();
pooling_param->pad_l_ = pooling_primitive->GetPadLeft();
pooling_param->pad_r_ = pooling_primitive->GetPadRight();
pooling_param->stride_w_ = pooling_primitive->GetStrideW();
pooling_param->stride_h_ = pooling_primitive->GetStrideH();
pooling_param->pool_mode_ = PoolMode_No;
pooling_param->round_mode_ = RoundMode_No;
switch (pooling_primitive->GetPoolingMode()) {
case schema::PoolMode_MAX_POOLING:
pooling_param->pool_mode_ = PoolMode_MaxPool;
break;
case schema::PoolMode_MEAN_POOLING:
pooling_param->pool_mode_ = PoolMode_AvgPool;
break;
default:
break;
}
switch (pooling_primitive->GetRoundMode()) {
case schema::RoundMode_FLOOR:
pooling_param->round_mode_ = RoundMode_Floor;
break;
case schema::RoundMode_CEIL:
pooling_param->round_mode_ = RoundMode_Ceil;
break;
default:
break;
}
return reinterpret_cast<OpParameter *>(pooling_param);
}
OpParameter *PopulateActivationGradParameter(const mindspore::lite::PrimitiveC *primitive) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
return nullptr;
}
ActivationParameter *act_param = new (std::nothrow) ActivationParameter();
if (act_param == nullptr) {
MS_LOG(ERROR) << "new ActivationParameter failed.";
return nullptr;
}
act_param->op_parameter_.type_ = primitive->Type();
auto activation =
reinterpret_cast<mindspore::lite::ActivationGrad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
act_param->type_ = static_cast<int>(activation->GetType());
act_param->alpha_ = activation->GetAlpha();
return reinterpret_cast<OpParameter *>(act_param);
}
OpParameter *PopulateConvolutionGradFilterParameter(const mindspore::lite::PrimitiveC *primitive) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
return nullptr;
}
ConvParameter *param = new (std::nothrow) ConvParameter();
if (param == nullptr) {
MS_LOG(ERROR) << "new Param for conv grad filter failed.";
return nullptr;
}
param->op_parameter_.type_ = primitive->Type();
auto convg_primitive =
reinterpret_cast<mindspore::lite::Conv2DGradFilter *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
param->kernel_h_ = convg_primitive->GetKernelH();
param->kernel_w_ = convg_primitive->GetKernelW();
param->stride_h_ = convg_primitive->GetStrideH();
param->stride_w_ = convg_primitive->GetStrideW();
param->dilation_h_ = convg_primitive->GetDilateH();
param->dilation_w_ = convg_primitive->GetDilateW();
param->pad_u_ = convg_primitive->GetPadUp();
param->pad_d_ = convg_primitive->GetPadDown();
param->pad_l_ = convg_primitive->GetPadLeft();
param->pad_r_ = convg_primitive->GetPadRight();
param->group_ = convg_primitive->GetGroup();
param->act_type_ = ActType_No;
switch (convg_primitive->GetActivationType()) {
case schema::ActivationType_RELU:
param->act_type_ = ActType_Relu;
break;
case schema::ActivationType_RELU6:
param->act_type_ = ActType_Relu6;
break;
default:
break;
}
return reinterpret_cast<OpParameter *>(param);
}
OpParameter *PopulateConvolutionGradInputParameter(const mindspore::lite::PrimitiveC *primitive) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
return nullptr;
}
ConvParameter *param = new (std::nothrow) ConvParameter();
if (param == nullptr) {
MS_LOG(ERROR) << "new Param for conv grad filter failed.";
return nullptr;
}
param->op_parameter_.type_ = primitive->Type();
auto convg_primitive =
reinterpret_cast<mindspore::lite::Conv2DGradInput *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
param->kernel_h_ = convg_primitive->GetKernelH();
param->kernel_w_ = convg_primitive->GetKernelW();
param->stride_h_ = convg_primitive->GetStrideH();
param->stride_w_ = convg_primitive->GetStrideW();
param->dilation_h_ = convg_primitive->GetDilateH();
param->dilation_w_ = convg_primitive->GetDilateW();
param->pad_u_ = convg_primitive->GetPadUp();
param->pad_d_ = convg_primitive->GetPadDown();
param->pad_l_ = convg_primitive->GetPadLeft();
param->pad_r_ = convg_primitive->GetPadRight();
param->group_ = convg_primitive->GetGroup();
param->act_type_ = ActType_No;
switch (convg_primitive->GetActivationType()) {
case schema::ActivationType_RELU:
param->act_type_ = ActType_Relu;
break;
case schema::ActivationType_RELU6:
param->act_type_ = ActType_Relu6;
break;
default:
break;
}
return reinterpret_cast<OpParameter *>(param);
}
OpParameter *PopulatePowerGradParameter(const mindspore::lite::PrimitiveC *primitive) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
return nullptr;
}
PowerParameter *power_param = new (std::nothrow) PowerParameter();
if (power_param == nullptr) {
MS_LOG(ERROR) << "new PowerParameter failed.";
return nullptr;
}
power_param->op_parameter_.type_ = primitive->Type();
auto power = reinterpret_cast<mindspore::lite::PowerGrad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
power_param->power_ = power->GetPower();
power_param->scale_ = power->GetScale();
power_param->shift_ = power->GetShift();
return reinterpret_cast<OpParameter *>(power_param);
}
void PopulateTrainParameters() {
auto ppr = PopulateParameterRegistry::GetInstance();
ppr->AddPopulateParameterFunc(schema::PrimitiveType_ApplyMomentum, DefaultPopulateParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_BiasGrad, PopulateArithmetic);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_SoftmaxCrossEntropy, PopulateSoftmaxCrossEntropyParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_ActivationGrad, PopulateActivationGradParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_TupleGetItem, DefaultPopulateParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_Depend, DefaultPopulateParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_BNGrad, DefaultPopulateParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradFilter, PopulateConvolutionGradFilterParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradInput, PopulateConvolutionGradInputParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_PoolingGrad, PopulatePoolingGradParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_PowerGrad, PopulatePowerGradParameter);
}
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_POPULATE_PARAMETER_H_
#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_POPULATE_PARAMETER_H_
#include "src/ops/primitive_c.h"
namespace mindspore::kernel {
void PopulateTrainParameters();
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_POPULATE_PARAMETER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/train_session.h"
#include <algorithm>
#include "utils/log_adapter.h"
#include "include/context.h"
#include "src/common/utils.h"
#include "mindspore/lite/src/ir/tensor.h"
#include "src/train/loss_kernel.h"
#include "src/train/train_populate_parameter.h"
#include "src/runtime/runtime_api.h"
#include "src/executor.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/fp32_grad/convolution.h"
namespace mindspore::session {
TrainSession::TrainSession() { kernel::PopulateTrainParameters(); }
void TrainSession::ReplaceOps() {
mindspore::lite::KernelRegistrar tmp(mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32,
mindspore::schema::PrimitiveType_Conv2D,
mindspore::kernel::CpuConvTrainFp32KernelCreator);
}
int TrainSession::CompileGraph(lite::Model *model) {
model_ = model;
ReplaceOps();
return LiteSession::CompileGraph(model);
}
void* TrainSession::ExportToBuf(void* buf, size_t *len) const {
// auto train_model_impl = (dynamic_cast<lite::train::TrainModelImpl*>(model_->model_impl()));
// return train_model_impl->ExportToBuf(buf, len);
return nullptr;
}
int TrainSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) {
auto ms_output_tensors = GetOutputs();
this->outputs_.clear();
for (auto ms_tensors : ms_output_tensors)
for (auto ms_tensor : ms_tensors.second)
this->outputs_.push_back((dynamic_cast<lite::tensor::LiteTensor*>(ms_tensor))->tensor());
if (train_mode_)
return LiteSession::RunGraph(before, after);
// object is expected to run only inference part of graph
// prepare a lit of kernels till the loss function -- temporary solution
std::vector<kernel::LiteKernel *> infference_kernels;
for (auto kernel : this->kernels_) {
if (dynamic_cast<const kernel::LossKernel*>(kernel) != nullptr)
break;
infference_kernels.push_back(kernel);
}
MS_EXCEPTION_IF_NULL(this->context_);
// TODO(Emir)
// SetMaxWokerNum(context_->thread_num_);
// context_->running_ = true;
lite::Executor executor;
if (before == nullptr && after == nullptr) {
return executor.Run(this->inputs_, this->outputs_, infference_kernels, this->context_->allocator.get());
} else {
return executor.Run(this->inputs_, this->outputs_, infference_kernels, this->context_->allocator.get(),
before, after);
}
}
void TrainSession::train() {
for (auto *kernel : kernels_) {
MS_ASSERT(nullptr != kernel);
kernel->train();
}
train_mode_ = true;
ext_output_map_.clear();
for (auto kernel : this->kernels_) {
if (dynamic_cast<const kernel::LossKernel*>(kernel) != nullptr) {
auto *ms_tensor = new lite::tensor::LiteTensor(kernel->out_tensors().at(0));
ext_output_map_[kernel->name()].emplace_back(ms_tensor);
}
}
}
void TrainSession::eval() {
for (auto *kernel : kernels_) {
MS_ASSERT(nullptr != kernel);
kernel->eval();
}
train_mode_ = false;
kernel::LiteKernel* last_kernel = nullptr;
// We should get in_kernels and then get all last kernels
ext_output_map_ = output_node_map_;
for (auto kernel : this->kernels_) {
if ((dynamic_cast<const kernel::LossKernel*>(kernel) != nullptr) &&
(last_kernel != nullptr)) {
auto *ms_tensor = new lite::tensor::LiteTensor(last_kernel->out_tensors().at(0));
ext_output_map_[last_kernel->name()].emplace_back(ms_tensor);
}
last_kernel = kernel;
}
}
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> TrainSession::GetOutputs() const {
return ext_output_map_;
}
std::vector<tensor::MSTensor *> TrainSession::GetOutputsByName(const std::string &name) const {
auto ret_vect = LiteSession::GetOutputsByNodeName(name); // TODO(emir): GetOutputsByTensorName?
if (ret_vect.size() > 0)
return ret_vect;
auto ret = ext_output_map_.find(name);
if (ret == ext_output_map_.end()) {
MS_LOG(WARNING) << "Node " << name << " is not an output node";
std::vector<mindspore::tensor::MSTensor *> empty_ret;
return empty_ret;
}
return ret->second;
}
} // namespace mindspore::session
......@@ -259,6 +259,10 @@ endif()
if (SUPPORT_TRAIN)
set(TEST_LITE_SRC
${TEST_LITE_SRC}
# ${LITE_DIR}/src/train/ops/train_ops.cc
${LITE_DIR}/src/train/train_populate_parameter.cc
${LITE_DIR}/src/train/train_session.cc
${LITE_DIR}/src/lite_session.cc
# ${SRC_DIR}/common/trans.cc
# ${SRC_DIR}/common/lite/trans_extends.cc
# ${SRC_DIR}/kernel/kernel_build_info.cc
......
......@@ -25,9 +25,10 @@
#include "mindspore/lite/src/ir/tensor.h"
#include "mindspore/lite/src/lite_kernel.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h"
#include "nnacl/fp32_grad/activation_grad.h"
namespace mindspore {
class TestActGradFp32 : public mindspore::CommonTest {
class TestActGradFp32 : public mindspore::CommonTest {
public:
TestActGradFp32() {}
};
......@@ -41,13 +42,14 @@ TEST_F(TestActGradFp32, ReluGradFp32) {
size_t input_size;
std::string input_path = "./test_data/activationGrad/relu_y_50.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
EXPECT_EQ(input_size, output_data_size * sizeof(float));
std::string yt_path = "./test_data/activationGrad/relu_yt_50.bin";
auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size));
EXPECT_EQ(input_size, output_data_size * sizeof(float));
auto output_data = new float[output_data_size];
// warm up loop
for (int i = 0; i < 3; i++) {
ReluGrad(yt_data, input_data, 50, output_data);
ReluGrad(yt_data, input_data, output_data_size, output_data);
}
int loop_count = 100;
......@@ -72,9 +74,9 @@ TEST_F(TestActGradFp32, ReluGradFp32) {
EXPECT_EQ(res, 0);
delete input_data;
delete[] input_data;
delete[] output_data;
delete yt_data;
delete[] yt_data;
MS_LOG(INFO) << "ReluGradFp32 passed";
}
......@@ -118,9 +120,9 @@ TEST_F(TestActGradFp32, Relu6GradFp32) {
EXPECT_EQ(res, 0);
delete input_data;
delete[] input_data;
delete[] output_data;
delete yt_data;
delete[] yt_data;
MS_LOG(INFO) << "Relu6GradFp32 passed";
}
......@@ -164,9 +166,9 @@ TEST_F(TestActGradFp32, LReluGradFp32) {
EXPECT_EQ(res, 0);
delete input_data;
delete[] input_data;
delete[] output_data;
delete yt_data;
delete[] yt_data;
MS_LOG(INFO) << "LReluGradFp32 passed";
}
......@@ -211,9 +213,9 @@ TEST_F(TestActGradFp32, SigmoidGradFp32) {
EXPECT_EQ(res, 0);
// lite::CompareOutput(output_data, output_path);
delete input_data;
delete[] input_data;
delete[] output_data;
delete yt_data;
delete[] yt_data;
MS_LOG(INFO) << "SigmoidGradFp32 passed";
}
......@@ -257,9 +259,9 @@ TEST_F(TestActGradFp32, tanhGradFp32) {
EXPECT_EQ(res, 0);
delete input_data;
delete[] input_data;
delete[] output_data;
delete yt_data;
delete[] yt_data;
MS_LOG(INFO) << "TanhGradFp32 passed";
}
......@@ -267,24 +269,25 @@ TEST_F(TestActGradFp32, hswishGradFp32) {
// runtime part
printf("Calculating runtime cost...\n");
uint64_t time_avg = 0;
size_t output_data_size = 50;
const size_t output_data_size = 10;
size_t input_size;
std::string input_path = "./test_data/activationGrad/hswish_x_50.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
EXPECT_EQ(input_size, output_data_size * sizeof(float));
std::string yt_path = "./test_data/activationGrad/hswish_yt_50.bin";
auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size));
EXPECT_EQ(input_size, output_data_size * sizeof(float));
auto output_data = new float[output_data_size];
// warm up loop
for (int i = 0; i < 3; i++) {
HSwishGrad(yt_data, input_data, 50, output_data);
HSwishGrad(yt_data, input_data, static_cast<int>(output_data_size), output_data);
}
int loop_count = 100;
auto time_start = mindspore::lite::GetTimeUs();
for (int i = 0; i < loop_count; i++) {
HSwishGrad(yt_data, input_data, 50, output_data);
HSwishGrad(yt_data, input_data, output_data_size, output_data);
}
auto time_end = mindspore::lite::GetTimeUs();
auto cost = time_end - time_start;
......@@ -292,7 +295,7 @@ TEST_F(TestActGradFp32, hswishGradFp32) {
printf("single thread running time : %f ms\n", time_avg / 1000.0f);
printf("==================output data=================\n");
for (int i = 0; i < 20; i++) {
for (int i = 0; i < std::min(output_data_size, 20UL); i++) {
std::cout << output_data[i] << " ,";
}
std::cout << std::endl;
......@@ -302,9 +305,9 @@ TEST_F(TestActGradFp32, hswishGradFp32) {
EXPECT_EQ(res, 0);
delete input_data;
delete[] input_data;
delete[] output_data;
delete yt_data;
delete[] yt_data;
MS_LOG(INFO) << "hswishGradFp32 passed";
}
......
......@@ -106,9 +106,14 @@ TEST_F(TestArithmeticGradFp32, TestAddGradFp32) {
std::string dx2_path = "./test_data/operators/arithmetic_fp32_1_dx2_1_6.bin";
EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path));
for (int i = 0; i < 5; i++) delete all_tensors[i];
delete param;
for (auto tensor : all_tensors) {
delete[] reinterpret_cast<float *>(tensor->Data());
tensor->SetData(nullptr);
delete tensor;
}
// delete all_tensors;
// delete param;
delete kernel_obj;
MS_LOG(INFO) << "TestAddGradFp32 passed";
}
......@@ -137,9 +142,14 @@ TEST_F(TestArithmeticGradFp32, TestAddGrad2Fp32) {
std::string dx2_path = "./test_data/operators/arithmetic_fp32_1_dx2_1_6.bin";
EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path));
for (int i = 0; i < 5; i++) delete all_tensors[i];
delete param;
for (auto tensor : all_tensors) {
delete[] reinterpret_cast<float *>(tensor->Data());
tensor->SetData(nullptr);
delete tensor;
}
// for (int i = 0; i < 5; i++) delete all_tensors[i]; //TODO tensor data is unique pointer
// delete param;
delete kernel_obj;
MS_LOG(INFO) << "TestAddGrad2Fp32 passed";
}
......@@ -169,8 +179,14 @@ TEST_F(TestArithmeticGradFp32, TestAddGrad3Fp32) {
std::string dx2_path = "./test_data/operators/arithmetic_fp32_8_dx1_5_4_6.bin";
EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path));
for (int i = 0; i < 5; i++) delete all_tensors[i];
delete param;
for (auto tensor : all_tensors) {
delete[] reinterpret_cast<float *>(tensor->Data());
tensor->SetData(nullptr);
delete tensor;
}
// for (int i = 0; i < 5; i++) delete all_tensors[i];
// delete param;
delete kernel_obj;
MS_LOG(INFO) << "TestAddGrad3Fp32 passed";
}
......@@ -200,8 +216,14 @@ TEST_F(TestArithmeticGradFp32, TestSubGradFp32) {
std::string dx2_path = "./test_data/operators/arithmetic_fp32_2_dx2_1_6.bin";
EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path));
for (int i = 0; i < 5; i++) delete all_tensors[i];
delete param;
for (auto tensor : all_tensors) {
delete[] reinterpret_cast<float *>(tensor->Data());
tensor->SetData(nullptr);
delete tensor;
}
// for (int i = 0; i < 5; i++) delete all_tensors[i];
// delete param;
delete kernel_obj;
MS_LOG(INFO) << "TestSubGradFp32 passed";
}
......@@ -231,8 +253,12 @@ TEST_F(TestArithmeticGradFp32, TestSubGrad2Fp32) {
std::string dx2_path = "./test_data/operators/arithmetic_fp32_3_dx2_1_6.bin";
EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path));
for (int i = 0; i < 5; i++) delete all_tensors[i];
delete param;
for (auto tensor : all_tensors) {
delete[] reinterpret_cast<float *>(tensor->Data());
tensor->SetData(nullptr);
delete tensor;
}
delete kernel_obj;
MS_LOG(INFO) << "TestSubGrad2Fp32 passed";
}
......@@ -271,9 +297,13 @@ TEST_F(TestArithmeticGradFp32, TestMulGradFp32) {
std::string dx2_path = "./test_data/operators/arithmetic_fp32_4_dx2_1_6.bin";
EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path));
for (int i = 0; i < 5; i++) delete all_tensors[i];
delete param;
for (auto tensor : all_tensors) {
delete[] reinterpret_cast<float *>(tensor->Data());
tensor->SetData(nullptr);
delete tensor;
}
delete kernel_obj;
// delete param;
MS_LOG(INFO) << "TestMulGradFp32 passed";
}
......@@ -302,9 +332,14 @@ TEST_F(TestArithmeticGradFp32, TestMulGrad2Fp32) {
std::string dx2_path = "./test_data/operators/arithmetic_fp32_4_dx2_1_6.bin";
EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path));
for (int i = 0; i < 5; i++) delete all_tensors[i];
delete param;
for (auto tensor : all_tensors) {
delete[] reinterpret_cast<float *>(tensor->Data());
tensor->SetData(nullptr);
delete tensor;
}
// for (int i = 0; i < 5; i++) delete all_tensors[i];
// delete param;
delete kernel_obj;
MS_LOG(INFO) << "TestMulGrad2Fp32 passed";
}
......@@ -333,9 +368,14 @@ TEST_F(TestArithmeticGradFp32, TestMulGrad3Fp32) {
std::string dx2_path = "./test_data/operators/arithmetic_fp32_9_dx2_5_1_6.bin";
EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path));
for (int i = 0; i < 5; i++) delete all_tensors[i];
delete param;
for (auto tensor : all_tensors) {
delete[] reinterpret_cast<float *>(tensor->Data());
tensor->SetData(nullptr);
delete tensor;
}
// for (int i = 0; i < 5; i++) delete all_tensors[i];
// delete param;
delete kernel_obj;
MS_LOG(INFO) << "TestMulGrad3Fp32 passed";
}
......@@ -364,9 +404,14 @@ TEST_F(TestArithmeticGradFp32, TestMulGrad4Fp32) {
std::string dx2_path = "./test_data/operators/arithmetic_fp32_9_dx2_5_1_6.bin";
EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path));
for (int i = 0; i < 5; i++) delete all_tensors[i];
delete param;
for (auto tensor : all_tensors) {
delete[] reinterpret_cast<float *>(tensor->Data());
tensor->SetData(nullptr);
delete tensor;
}
// for (int i = 0; i < 5; i++) delete all_tensors[i];
// delete param;
delete kernel_obj;
MS_LOG(INFO) << "TestMulGrad4Fp32 passed";
}
......@@ -395,9 +440,14 @@ TEST_F(TestArithmeticGradFp32, TestDivGradFp32) {
std::string dx2_path = "./test_data/operators/arithmetic_fp32_5_dx2_1_6.bin";
EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path));
for (int i = 0; i < 5; i++) delete all_tensors[i];
delete param;
for (auto tensor : all_tensors) {
delete[] reinterpret_cast<float *>(tensor->Data());
tensor->SetData(nullptr);
delete tensor;
}
// for (int i = 0; i < 5; i++) delete all_tensors[i];
delete kernel_obj;
// delete param;
MS_LOG(INFO) << "TestDivGradFp32 passed";
}
......@@ -427,8 +477,14 @@ TEST_F(TestArithmeticGradFp32, TestDivGrad2Fp32) {
std::string output_path = "./test_data/operators/arithmetic_fp32_6_dx1_1_6.bin";
EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, output_path));
for (int i = 0; i < 5; i++) delete all_tensors[i];
delete param;
for (auto tensor : all_tensors) {
delete[] reinterpret_cast<float *>(tensor->Data());
tensor->SetData(nullptr);
delete tensor;
}
// for (int i = 0; i < 5; i++) delete all_tensors[i];
// delete param;
delete kernel_obj;
MS_LOG(INFO) << "TestDivGrad2Fp32 passed";
}
......@@ -457,9 +513,14 @@ TEST_F(TestArithmeticGradFp32, TestDivGrad3Fp32) {
std::string output_path = "./test_data/operators/arithmetic_fp32_10_dx2_5_1_6.bin";
EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, output_path));
for (int i = 0; i < 5; i++) delete all_tensors[i];
delete param;
for (auto tensor : all_tensors) {
delete[] reinterpret_cast<float *>(tensor->Data());
tensor->SetData(nullptr);
delete tensor;
}
// for (int i = 0; i < 5; i++) delete all_tensors[i];
// delete param;
delete kernel_obj;
MS_LOG(INFO) << "TestDivGrad3Fp32 passed";
}
......@@ -488,9 +549,12 @@ TEST_F(TestArithmeticGradFp32, Test3DDivGrad2Fp32) {
std::string output_path = "./test_data/operators/arithmetic_fp32_7_dx2_1_1_6.bin";
EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, output_path));
for (int i = 0; i < 5; i++) delete all_tensors[i];
delete param;
for (auto tensor : all_tensors) {
delete[] reinterpret_cast<float *>(tensor->Data());
tensor->SetData(nullptr);
delete tensor;
}
delete kernel_obj;
MS_LOG(INFO) << "TestDivGrad2Fp32 passed";
}
......
......@@ -18,8 +18,8 @@
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "src/runtime/kernel/arm/fp32_grad/bias_grad.h"
#include "src/kernel_registry.h"
namespace mindspore {
......@@ -40,9 +40,8 @@ TEST_F(TestBiasGradFp32, BiasGradFp32) {
dy_tensor.SetData(input_data);
std::vector<lite::tensor::Tensor *> inputs = {&dy_tensor};
auto output_data = new float[7];
std::vector<int> dim_dw({7});
std::vector<int> dim_dw = {7};
lite::tensor::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw);
dw_tensor.SetData(output_data);
std::vector<lite::tensor::Tensor *> outputs = {&dw_tensor};
......@@ -62,9 +61,12 @@ TEST_F(TestBiasGradFp32, BiasGradFp32) {
std::string output_path = "./test_data/operators/biasgradfp32_1_db_7.bin";
lite::CompareOutput(output_data, output_path);
// delete input_data;
// delete[] output_data;
delete bias_param;
delete [] input_data;
delete[] output_data;
// delete bias_param;
dy_tensor.SetData(nullptr);
dw_tensor.SetData(nullptr);
delete kernel_obj;
MS_LOG(INFO) << "BiasGradFp32 passed";
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "src/common/file_utils_ext.h"
#include "src/runtime/kernel/arm/fp32_grad/bn_grad.h"
#include "nnacl/fp32_grad/batch_norm.h"
#include "src/kernel_registry.h"
#
namespace mindspore {
class TestBNGradFp32 : public mindspore::CommonTest {
public:
TestBNGradFp32() {}
lite::tensor::Tensor *CreateInTensor(std::string file_name, std::vector<int> dim);
};
lite::tensor::Tensor *TestBNGradFp32::CreateInTensor(std::string file_name, std::vector<int> dim) {
size_t input_size = 0;
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_name.c_str(), &input_size));
auto tensor = new lite::tensor::Tensor(TypeId::kNumberTypeFloat32, dim);
tensor->SetData(input_data);
EXPECT_EQ(input_size, tensor->Size());
return tensor;
}
TEST_F(TestBNGradFp32, BNGradFp32) {
// prepare stage
auto bn_param = new BNGradParameter();
bn_param->epsilon_ = 0.00001;
bn_param->momentum_ = 0.1;
const int batch = 2;
const int channels = 3;
const int height = 4;
const int width = 5;
auto dy_tensor = CreateInTensor("./test_data/bngrad/dy_2_4_5_3.bin", {batch, height, width, channels});
auto x_tensor = CreateInTensor("./test_data/bngrad/input_x_2_4_5_3.bin", {batch, height, width, channels});
auto scale_tensor = CreateInTensor("./test_data/bngrad/scale_3.bin", {1, 1, 1, channels});
auto mean_tensor = CreateInTensor("./test_data/bngrad/save_mean_3.bin", {1, 1, 1, channels});
auto var_tensor = CreateInTensor("././test_data/bngrad/save_var_3.bin", {1, 1, 1, channels});
// prepare output tensors
lite::tensor::Tensor dx_tensor(TypeId::kNumberTypeFloat32, {batch, height, width, channels});
dx_tensor.MallocData();
lite::tensor::Tensor dscale_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels});
dscale_tensor.MallocData();
lite::tensor::Tensor dbias_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels});
dbias_tensor.MallocData();
std::vector<lite::tensor::Tensor *> inputs = {dy_tensor, x_tensor, scale_tensor, mean_tensor, var_tensor};
std::vector<lite::tensor::Tensor *> outputs = {&dx_tensor, &dscale_tensor, &dbias_tensor};
kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_BNGrad};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(bn_param), NULL, desc, nullptr);
for (int i = 0; i < 3; i++) {
kernel_obj->Run();
}
int loop_count = 100;
auto time_start = mindspore::lite::GetTimeUs();
for (int i = 0; i < loop_count; i++) {
kernel_obj->Run();
}
auto time_end = mindspore::lite::GetTimeUs();
auto cost = time_end - time_start;
auto time_avg = cost / loop_count;
std::cout << "single thread running time : " << time_avg << "us\n";
std::cout << "==========dx==========\n";
auto dx = reinterpret_cast<float *>(outputs[0]->Data());
for (int i = 0; i < 7; i++) std::cout << dx[i] << " ";
std::cout << "\n=======dscale=======\n";
auto dscale = reinterpret_cast<float *>(outputs[1]->Data());
for (int i = 0; i < channels; i++) std::cout << dscale[i] << " ";
std::cout << "\n";
int res = mindspore::lite::CompareRelativeOutput(dscale, "./test_data/bngrad/output_dscale_3.bin");
EXPECT_EQ(res, 0);
std::cout << "==========dbias==========\n";
auto dbias = reinterpret_cast<float *>(outputs[2]->Data());
for (int i = 0; i < 3; i++) std::cout << dbias[i] << " ";
std::cout << "\n";
res = mindspore::lite::CompareRelativeOutput(dscale, "./test_data/bngrad/output_dscale_3.bin");
EXPECT_EQ(res, 0);
for (auto v : inputs) {
delete[] reinterpret_cast<float *>(v->Data());
v->SetData(nullptr);
// delete v;
}
delete kernel_obj;
MS_LOG(INFO) << "BNGradFp32 passed";
}
} // namespace mindspore
......@@ -21,6 +21,7 @@
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "src/common/file_utils_ext.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h"
#include "mindspore/lite/nnacl/conv_parameter.h"
......@@ -130,11 +131,14 @@ TEST_F(TestConvolutionGradFp32, ConvFp32FilterGrad) {
EXPECT_EQ(res, 0);
// delete input_data;
// delete dy_data;
// delete [] dw_data;
delete [] input_data;
delete [] dy_data;
delete [] dw_data;
delete kernel;
delete conv_param;
// delete conv_param;
dw_tensor.SetData(nullptr);
x_tensor.SetData(nullptr);
dy_tensor.SetData(nullptr);
MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed";
}
......@@ -193,9 +197,15 @@ TEST_F(TestConvolutionGradFp32, ConvFp32InputGrad) {
std::string output_path = "./test_data/conv/convfp32_dx_1_28_28_3.bin";
auto res = lite::CompareRelativeOutput(dx_data, output_path);
EXPECT_EQ(res, 0);
delete [] dx_data;
delete [] w_data;
delete [] dy_data;
w_tensor.SetData(nullptr);
dy_tensor.SetData(nullptr);
dx_tensor.SetData(nullptr);
delete kernel;
delete conv_param;
// delete conv_param;
MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed";
}
......@@ -254,11 +264,14 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupFilterGrad) {
auto res = lite::CompareRelativeOutput(dw_data, output_path);
EXPECT_EQ(res, 0);
// delete input_data;
// delete dy_data;
// delete [] dw_data;
delete [] input_data;
delete [] dy_data;
delete [] dw_data;
dw_tensor.SetData(nullptr);
x_tensor.SetData(nullptr);
dy_tensor.SetData(nullptr);
delete kernel;
delete conv_param;
// delete conv_param;
MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed";
}
......@@ -317,9 +330,15 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupInputGrad) {
std::string output_path = "./test_data/conv/convfp32_dx_g3_1_28_28_3.bin";
auto res = lite::CompareRelativeOutput(dx_data, output_path);
EXPECT_EQ(res, 0);
delete [] dx_data;
delete [] w_data;
delete [] dy_data;
dx_tensor.SetData(nullptr);
w_tensor.SetData(nullptr);
dy_tensor.SetData(nullptr);
delete kernel;
delete conv_param;
// delete conv_param;
MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed";
}
......@@ -378,11 +397,14 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationFilterGrad) {
std::string output_path = "./test_data/conv/convfp32_dw_g3_d2_18_3_3_3.bin";
auto res = lite::CompareRelativeOutput(dw_data, output_path);
EXPECT_EQ(res, 0);
// delete input_data;
// delete dy_data;
// delete [] dw_data;
delete [] input_data;
delete [] dy_data;
delete [] dw_data;
dw_tensor.SetData(nullptr);
dy_tensor.SetData(nullptr);
x_tensor.SetData(nullptr);
delete kernel;
delete conv_param;
// delete conv_param;
MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed";
}
......@@ -441,80 +463,93 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationInputGrad) {
std::string output_path = "./test_data/conv/convfp32_dx_g3_d2_1_28_28_3.bin";
auto res = lite::CompareRelativeOutput(dx_data, output_path);
EXPECT_EQ(res, 0);
delete [] dx_data;
delete [] w_data;
delete [] dy_data;
dx_tensor.SetData(nullptr);
dy_tensor.SetData(nullptr);
w_tensor.SetData(nullptr);
delete kernel;
delete conv_param;
// delete conv_param;
MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed";
}
// TEST_F(TestConvolutionGradFp32, ConvGroupDilation) {
// // prepare stage
// auto conv_param = new ConvParameter();
// InitConvParamGroup3Dilation2FP32(conv_param);
// size_t x_size;
// std::string x_path = "./test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin";
// auto x_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(x_path.c_str(), &x_size));
// std::vector<int> dim_x({1, 28, 28, 3});
// tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x);
// x_tensor.SetData(x_data);
// size_t w_size;
// std::string w_path = "./test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin";
// auto w_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(w_path.c_str(), &w_size));
// std::vector<int> dim_w({18, 3, 3, 1});
// tensor::Tensor w_tensor(TypeId::kNumberTypeFloat32, dim_w);
// w_tensor.SetData(w_data);
// size_t output_data_size =
// conv_param->output_batch_ * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
// auto y_data = new float[output_data_size];
// std::vector<int> dim_y({1, 26, 26, 18});
// tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y);
// y_tensor.SetData(y_data);
// std::vector<tensor::Tensor *> inputs = {&x_tensor, &w_tensor};
// std::vector<tensor::Tensor *> outputs = {&y_tensor};
// // runtime part
// printf("Calculating runtime cost...\n");
// uint64_t time_avg = 0;
// lite::Context context;
// ;
// context.deviceCtx.type = lite::DT_CPU;
// context.threadNum = 1;
// kernel::KernelKey desc = {kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Conv2D};
// auto creator = lite::KernelRegistry::GetInstance()->GetKernelCreator(desc);
// auto kernel = creator(inputs, outputs, (OpParameter *)conv_param, &context, desc);
// kernel->train();
// EXPECT_EQ(kernel->is_train(), 1);
// // warm up loop
// for (int i = 0; i < 3; i++) {
// kernel->Run();
// }
// int loop_count = 100;
// auto time_start = mindspore::lite::GetTimeUs();
// for (int i = 0; i < loop_count; i++) {
// kernel->Run();
// }
// auto time_end = mindspore::lite::GetTimeUs();
// auto cost = time_end - time_start;
// time_avg = cost / loop_count;
// printf("single thread running time : %f ms\n", time_avg / 1000.0f);
// std::string output_path = "./test_data/conv/convfp32_y_g3_d2_1_26_26_18.bin";
// auto res = lite::CompareRelativeOutput(y_data, output_path);
// EXPECT_EQ(res, 0);
// delete kernel;
// delete conv_param;
// MS_LOG(INFO) << "TestConvolutionFp32 Filter Grad passed";
// }
TEST_F(TestConvolutionGradFp32, ConvGroupDilation) {
// prepare stage
auto conv_param = new ConvParameter();
InitConvParamGroup3Dilation2FP32(conv_param);
size_t x_size;
std::string x_path = "./test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin";
auto x_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(x_path.c_str(), &x_size));
std::vector<int> dim_x({1, 28, 28, 3});
lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x);
x_tensor.SetData(x_data);
size_t w_size;
std::string w_path = "./test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin";
auto w_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(w_path.c_str(), &w_size));
std::vector<int> dim_w({18, 3, 3, 1});
lite::tensor::Tensor w_tensor(TypeId::kNumberTypeFloat32, dim_w);
w_tensor.SetData(w_data);
size_t output_data_size =
conv_param->output_batch_ * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
auto y_data = new float[output_data_size];
std::vector<int> dim_y({1, 26, 26, 18});
lite::tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y);
y_tensor.SetData(y_data);
std::vector<lite::tensor::Tensor *> inputs = {&x_tensor, &w_tensor};
std::vector<lite::tensor::Tensor *> outputs = {&y_tensor};
// runtime part
printf("Calculating runtime cost...\n");
uint64_t time_avg = 0;
lite::Context context;
context.device_ctx_.type = lite::DT_CPU;
context.thread_num_ = 1;
auto *kernel = new mindspore::kernel::ConvolutionTrainCPUKernel(reinterpret_cast<OpParameter *>(conv_param),
inputs, outputs, &context, 0);
kernel->Init();
// kernel::KernelKey desc = {kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Conv2D};
// auto creator = lite::KernelRegistry::GetInstance()->GetKernelCreator(desc);
// auto kernel = creator(inputs, outputs, (OpParameter *)conv_param, &context, desc);
kernel->train();
EXPECT_EQ(kernel->is_train(), 1);
// warm up loop
for (int i = 0; i < 3; i++) {
kernel->Run();
}
int loop_count = 100;
auto time_start = mindspore::lite::GetTimeUs();
for (int i = 0; i < loop_count; i++) {
kernel->Run();
}
auto time_end = mindspore::lite::GetTimeUs();
auto cost = time_end - time_start;
time_avg = cost / loop_count;
printf("single thread running time : %f ms\n", time_avg / 1000.0f);
std::string output_path = "./test_data/conv/convfp32_y_g3_d2_1_26_26_18.bin";
auto res = lite::CompareRelativeOutput(y_data, output_path);
EXPECT_EQ(res, 0);
delete [] y_data;
delete [] x_data;
delete [] w_data;
x_tensor.SetData(nullptr);
y_tensor.SetData(nullptr);
w_tensor.SetData(nullptr);
delete kernel;
MS_LOG(INFO) << "TestConvolutionFp32 Filter Grad passed";
}
} // namespace mindspore
......@@ -40,7 +40,7 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) {
y_tensor.SetData(input_data);
std::string label_path = "./test_data/operators/sce_fp32_1_l_6.bin";
auto ll_labels = reinterpret_cast<int64 *>(mindspore::lite::ReadFile(label_path.c_str(), &input_size));
auto ll_labels = reinterpret_cast<int64_t *>(mindspore::lite::ReadFile(label_path.c_str(), &input_size));
auto labels = new int[6];
for (int i = 0; i < 6; i++) labels[i] = static_cast<int>(ll_labels[i]);
......@@ -57,7 +57,7 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) {
auto grad = new float[24];
lite::tensor::Tensor grad_tensor(TypeId::kNumberTypeFloat32, dim_y);
grad_tensor.SetData(grad);
std::vector<lite::tensor::Tensor *> outputs = {&grad_tensor, &loss_tensor};
std::vector<lite::tensor::Tensor *> outputs = {&loss_tensor, &grad_tensor};
kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SoftmaxCrossEntropy};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
......
V_?Kϧ࿅>J?/="m?Luj@!U$?f=?e[?Wھ m ? eO?}4?B<Keڽp~|>?7E :?JͿ̬> ? ~?ϫN1?> HV|ʾ={IU?xvW>[$?]4Bu 4@+?z>uB?=|e >M>>?}0?> @<?v?vZ?zſ@.ο8B?o Ծq"mn?k>=">: @<>+R
b6.?i?v?`j6R~]?JU6sG?M% ?h>ȿ G½?>ӓ'6?@2/VK5T>X]?[?v_ؿj?p?\l?.l=b?
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册