diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs
index d2d7b806e87e28c403b2215dd8c346a7db5ce374..286077299d41bd75828d744c6923f0f61f8d1c37 100644
--- a/mindspore/lite/schema/model.fbs
+++ b/mindspore/lite/schema/model.fbs
@@ -179,6 +179,7 @@ union PrimitiveType {
     Conv2DGradInput,
     PoolingGrad,
     BNGrad,
+    BNGradInput,
     ApplyMomentum,
     BiasGrad,
     SoftmaxCrossEntropy,
diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index 3ad035c0f7741e88fbd135bda8d2c9f40bf1fde3..2e13afa463605b2f5f099cc9c3299a9285308d9d 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -398,7 +398,10 @@ table BNGrad {
     eps : float;
     momentum: float;
 }
-
+table BNGradInput {
+    eps : float;
+    momentum: float;
+}
 table Scale {
     axis: int;
 }
diff --git a/mindspore/lite/src/ops/activation_grad.cc b/mindspore/lite/src/ops/activation_grad.cc
index 0554e53040169e2a82a63a2e594180038f6cc2ff..0efed3a6ab3f49c0741a70ef2badd0da38824c09 100644
--- a/mindspore/lite/src/ops/activation_grad.cc
+++ b/mindspore/lite/src/ops/activation_grad.cc
@@ -25,6 +25,36 @@ void ActivationGrad::SetType(int type) {
   this->primitive_->value.AsActivationGrad()->type = (schema::ActivationType)type;
 }
 void ActivationGrad::SetAlpha(float alpha) { this->primitive_->value.AsActivationGrad()->alpha = alpha; }
+int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_ActivationGrad;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_ActivationGrad) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  auto attr = std::make_unique<schema::ActivationGradT>();
+  if (prim.name() == "ReLU") {
+    attr->type = schema::ActivationType_RELU;
+  } else if (prim.name() == "Sigmoid") {
+    attr->type = schema::ActivationType_SIGMOID;
+  } else if (prim.name() == "ReLU6") {
+    attr->type = schema::ActivationType_RELU6;
+  }
+  auto alpha = GetValue<float>(prim.GetAttr("alpha"));
+  attr->alpha = alpha;
+  this->primitive_->value.value = attr.release();
+  if (this->primitive_->value.value == nullptr) {
+    MS_LOG(ERROR) << "new primitiveT value failed";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
 #else
 int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
diff --git a/mindspore/lite/src/ops/activation_grad.h b/mindspore/lite/src/ops/activation_grad.h
index 463043bd7b49d8567b9442a06f786fca9ae59520..907c5c589e68b97d015aaa3ea1c21941e813cc82 100644
--- a/mindspore/lite/src/ops/activation_grad.h
+++ b/mindspore/lite/src/ops/activation_grad.h
@@ -20,6 +20,7 @@
 #include <vector>
 #include <set>
 #include <cmath>
+#include <memory>
 #include "ir/dtype/type_id.h"
 #include "src/ops/primitive_c.h"
@@ -33,6 +34,7 @@ class ActivationGrad : public PrimitiveC {
   explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
   void SetType(int type);
   void SetAlpha(float alpha);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   ActivationGrad() = default;
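ActivationGrad::UnPackAttr above maps the frontend primitive name onto the schema activation enum with an if/else chain. For reference, the same mapping as a lookup table; this is a standalone sketch with simplified stand-in types, not the generated schema enums:

    #include <string>
    #include <unordered_map>

    enum class ActivationType { NO_ACTIVATION, RELU, SIGMOID, RELU6 };

    // Table form of the name -> type chain; unknown names fall back to
    // NO_ACTIVATION (the chain above simply leaves the default in place).
    ActivationType TypeFromPrimName(const std::string &name) {
      static const std::unordered_map<std::string, ActivationType> kTypes = {
          {"ReLU", ActivationType::RELU},
          {"Sigmoid", ActivationType::SIGMOID},
          {"ReLU6", ActivationType::RELU6},
      };
      auto it = kTypes.find(name);
      return it != kTypes.end() ? it->second : ActivationType::NO_ACTIVATION;
    }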
diff --git a/mindspore/lite/src/ops/bias_grad.cc b/mindspore/lite/src/ops/bias_grad.cc
index d571ed3d0911f0575e914c12f5bb4826fbd08066..c3c4ac899bc1095636845803f5c0ea2fa49f9d46 100644
--- a/mindspore/lite/src/ops/bias_grad.cc
+++ b/mindspore/lite/src/ops/bias_grad.cc
@@ -22,7 +22,34 @@ namespace lite {
 std::vector<int> BiasGrad::GetAxis() const { return this->primitive_->value.AsBiasGrad()->axis; }
 void BiasGrad::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasGrad()->axis = axis; }
-
+int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_BiasGrad;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_BiasGrad) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::BiasGradT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+    attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "primitive value is nullptr";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
 #else
 int BiasGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
diff --git a/mindspore/lite/src/ops/bias_grad.h b/mindspore/lite/src/ops/bias_grad.h
index 37532961744437d5f2b6b9a82eaa2136b5832c6b..a96fde9aeb6233a7e15f49062572eae75ee2d183 100644
--- a/mindspore/lite/src/ops/bias_grad.h
+++ b/mindspore/lite/src/ops/bias_grad.h
@@ -33,7 +33,7 @@ class BiasGrad : public PrimitiveC {
   BiasGrad() = default;
   explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
   void SetAxis(const std::vector<int> &axis);
-
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   BiasGrad() = default;
diff --git a/mindspore/lite/src/ops/bn_grad.h b/mindspore/lite/src/ops/bn_grad.h
index e346593a53599448351b72952de393aa53001c70..0d09639bfef04ae8daf00cd60634535a31ca8eca 100644
--- a/mindspore/lite/src/ops/bn_grad.h
+++ b/mindspore/lite/src/ops/bn_grad.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
-#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
+#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_
+#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_

 #include <vector>
 #include <set>
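BiasGrad::UnPackAttr above is the template every op in this patch follows: allocate the PrimitiveT lazily, verify the union's type tag, and fill the attribute table only if it is still empty. A condensed standalone sketch of that control flow, with hypothetical stand-ins (PrimT, AttrT, kThisOpType) in place of the schema types:

    #include <new>
    #include <vector>

    struct AttrT { std::vector<int> axis; };
    struct PrimT { int type = 0; void *value = nullptr; };
    constexpr int kThisOpType = 1;
    constexpr int RET_OK = 0;
    constexpr int RET_ERROR = -1;

    int UnPackAttr(PrimT *&prim, const std::vector<int> &axis) {
      if (prim == nullptr) {                 // first caller allocates
        prim = new (std::nothrow) PrimT;
        if (prim == nullptr) return RET_ERROR;
        prim->type = kThisOpType;            // claim the union's type tag
      }
      if (prim->type != kThisOpType) return RET_ERROR;  // tag held by another op
      if (prim->value == nullptr) {          // populate attributes only once
        auto *attr = new (std::nothrow) AttrT{axis};
        if (attr == nullptr) return RET_ERROR;
        prim->value = attr;
      }
      return RET_OK;
    }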
diff --git a/mindspore/lite/src/ops/bn_grad_input.cc b/mindspore/lite/src/ops/bn_grad_input.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d243764979fbe2af644283dbf8a61dcca9ecea3e
--- /dev/null
+++ b/mindspore/lite/src/ops/bn_grad_input.cc
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ops/bn_grad_input.h"
+
+namespace mindspore {
+namespace lite {
+#ifdef PRIMITIVE_WRITEABLE
+float BNGradInput::GetEps() const { return this->primitive_->value.AsBNGradInput()->eps; }
+float BNGradInput::GetMomentum() const { return this->primitive_->value.AsBNGradInput()->momentum; }
+
+void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->eps = eps; }
+void BNGradInput::SetMomentum(float momentum) { this->primitive_->value.AsBNGradInput()->momentum = momentum; }
+int BNGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_BNGradInput;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_BNGradInput) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::BNGradInputT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+    attr->eps = GetValue<float>(prim.GetAttr("eps"));
+    attr->momentum = GetValue<float>(prim.GetAttr("momentum"));
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "primitive value is nullptr";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+#else
+int BNGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
+  MS_ASSERT(nullptr != primitive);
+  MS_ASSERT(nullptr != fbb);
+  auto attr = primitive->value_as_BNGradInput();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "value_as_BNGradInput return nullptr";
+    return RET_ERROR;
+  }
+  auto val_offset = schema::CreateBNGradInput(*fbb, attr->eps(), attr->momentum());
+  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGradInput, val_offset.o);
+  fbb->Finish(prim_offset);
+  return RET_OK;
+}
+float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); }
+float BNGradInput::GetMomentum() const { return this->primitive_->value_as_BNGradInput()->momentum(); }
+
+#endif
+}  // namespace lite
+}  // namespace mindspore
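BNGradInput unpacks the same eps and momentum attributes as BNGrad. As a reminder of the role eps plays (standard batch-norm arithmetic, not code from this patch), the normalized activation divides by sqrt(var + eps), so eps keeps the denominator away from zero:

    #include <cmath>

    // y = (x - mean) / sqrt(var + eps) for one element; eps guards
    // against division by ~0 when the batch variance is tiny.
    float BatchNormalize(float x, float mean, float var, float eps) {
      return (x - mean) / std::sqrt(var + eps);
    }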
diff --git a/mindspore/lite/src/ops/bn_grad_input.h b/mindspore/lite/src/ops/bn_grad_input.h
new file mode 100644
index 0000000000000000000000000000000000000000..52645f83f5019d16b916ba56f9d1308fff05e77e
--- /dev/null
+++ b/mindspore/lite/src/ops/bn_grad_input.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
+#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
+
+#include <vector>
+#include <set>
+#include <cmath>
+#include "ir/dtype/type_id.h"
+#include "src/ops/primitive_c.h"
+
+namespace mindspore {
+namespace lite {
+class BNGradInput : public PrimitiveC {
+ public:
+#ifdef PRIMITIVE_WRITEABLE
+  MS_DECLARE_PARENT(BNGradInput, PrimitiveC);
+  BNGradInput() = default;
+  explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
+  void SetEps(float eps);
+  void SetMomentum(float momentum);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
+#else
+  BNGradInput() = default;
+  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
+#endif
+  float GetEps() const;
+  float GetMomentum() const;
+};
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.cc b/mindspore/lite/src/ops/conv2d_grad_filter.cc
index 32866852944c312968c5fa73f94f32d8576cfba0..37d05fe3f333b43eaf4addff3c7541df71598acc 100644
--- a/mindspore/lite/src/ops/conv2d_grad_filter.cc
+++ b/mindspore/lite/src/ops/conv2d_grad_filter.cc
@@ -66,7 +66,133 @@ void Conv2DGradFilter::SetHasBias(bool has_bias) { this->primitive_->value.AsCon
 void Conv2DGradFilter::SetActivationType(int activation_type) {
   this->primitive_->value.AsConv2DGradFilter()->activationType = (schema::ActivationType)activation_type;
 }
+void Conv2DGradFilter::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
+                                                 const std::vector<AnfNodePtr> &inputs) {
+  auto attr = std::make_unique<schema::DepthwiseConv2DT>();
+  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
+  if (format == "NCHW") {
+    attr->format = schema::Format_NCHW;
+  } else if (format == "NHWC") {
+    attr->format = schema::Format_NHWC;
+  } else {
+    attr->format = schema::Format_NUM_OF_FORMAT;
+  }
+  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
+  attr->padUp = pad_list[0];
+  attr->padDown = pad_list[1];
+  attr->padLeft = pad_list[2];
+  attr->padRight = pad_list[3];
+
+  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
+  attr->dilateH = dilation[0];
+  attr->dilateW = dilation[1];
+
+  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
+  attr->kernelH = kernel_size[0];
+  attr->kernelW = kernel_size[1];
+
+  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
+  attr->strideH = stride[2];
+  attr->strideW = stride[3];
+
+  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
+  if (pad_mode == "valid") {
+    attr->padMode = schema::PadMode_VALID;
+  } else if (pad_mode == "same") {
+    attr->padMode = schema::PadMode_SAME;
+  } else {
+    attr->padMode = schema::PadMode_NOTSET;
+  }
+
+  if (prim.GetAttr("activation_name") != nullptr) {
+    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
+    attr->activationType = kActivationTypeMap[activate_name];
+  } else {
+    attr->activationType = schema::ActivationType_NO_ACTIVATION;
+  }
+
+  int channel_mutiplier = 1;
+  if (prim.GetAttr("channel_mutiplier") != nullptr) {
+    channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
+  }
+  attr->channelMultiplier = channel_mutiplier;
+
+  primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
+  primitive->value.value = attr.release();
+}
+
+void Conv2DGradFilter::PopulaterConv2DSingleGroup(const Primitive &prim,
+                                                  schema::PrimitiveT *primitive, const int &group) {
+  auto attr = std::make_unique<schema::Conv2DT>();
+  attr->group = group;
+  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
+  if (format == "NCHW") {
+    attr->format = schema::Format_NCHW;
+  } else if (format == "NHWC") {
+    attr->format = schema::Format_NHWC;
+  } else {
+    attr->format = schema::Format_NUM_OF_FORMAT;
+  }
+  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
+  attr->padUp = pad_list[0];
+  attr->padDown = pad_list[1];
+  attr->padLeft = pad_list[2];
+  attr->padRight = pad_list[3];
+
+  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
+  attr->dilateH = dilation[0];
+  attr->dilateW = dilation[1];
+  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
+  attr->kernelH = kernel_size[0];
+  attr->kernelW = kernel_size[1];
+
+  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
+  attr->strideH = stride[2];
+  attr->strideW = stride[3];
+
+  attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
+
+  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
+  if (pad_mode == "valid") {
+    attr->padMode = schema::PadMode_VALID;
+  } else if (pad_mode == "same") {
+    attr->padMode = schema::PadMode_SAME;
+  } else {
+    attr->padMode = schema::PadMode_NOTSET;
+  }
+
+  if (prim.GetAttr("activation_name") != nullptr) {
+    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
+    attr->activationType = kActivationTypeMap[activate_name];
+  } else {
+    attr->activationType = schema::ActivationType_NO_ACTIVATION;
+  }
+
+  primitive->value.type = schema::PrimitiveType_Conv2D;
+  primitive->value.value = attr.release();
+}
+int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Conv2DGradFilter;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Conv2DGradFilter) {
+    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  int group = GetValue<int>(prim.GetAttr("group"));
+  if (group > 1) {
+    PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
+  } else {
+    PopulaterConv2DSingleGroup(prim, this->primitive_, group);
+  }
+  return RET_OK;
+}
 #else
 int Conv2DGradFilter::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.h b/mindspore/lite/src/ops/conv2d_grad_filter.h
index 96c189d3be0bb858c98681835006db05f725dac8..46917b5413a3ec15fc17e3eb45c9c35185f1f1b9 100644
--- a/mindspore/lite/src/ops/conv2d_grad_filter.h
+++ b/mindspore/lite/src/ops/conv2d_grad_filter.h
@@ -20,6 +20,8 @@
 #include <vector>
 #include <set>
 #include <cmath>
+#include <string>
+#include <memory>
 #include "ir/dtype/type_id.h"
 #include "src/ops/primitive_c.h"
@@ -48,6 +50,10 @@ class Conv2DGradFilter : public PrimitiveC {
   void SetDilateH(int dilate_h);
   void SetHasBias(bool has_bias);
   void SetActivationType(int activation_type);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
+  void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
+                                 const std::vector<AnfNodePtr> &inputs);
+  void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
 #else
   Conv2DGradFilter() = default;
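Two details of the Conv2DGradFilter unpacking above are easy to miss: group > 1 re-tags the primitive as a DepthwiseConv2D rather than a Conv2D, and strideH/strideW are read from stride[2]/stride[3], which suggests the frontend stride attribute is a 4-element NCHW-ordered list. A small sketch of both conventions (the layout assumption is inferred from the indices, not stated in the patch):

    #include <utility>
    #include <vector>

    enum class ConvKind { Regular, Depthwise };

    // group > 1 is treated as depthwise, mirroring UnPackAttr's dispatch.
    ConvKind KindFromGroup(int group) { return group > 1 ? ConvKind::Depthwise : ConvKind::Regular; }

    // Picks the spatial entries out of an (N, C, H, W)-ordered attribute.
    std::pair<int, int> SpatialHW(const std::vector<int> &nchw) { return {nchw[2], nchw[3]}; }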
diff --git a/mindspore/lite/src/ops/conv2d_grad_input.cc b/mindspore/lite/src/ops/conv2d_grad_input.cc
index 1c078b9cd67c3212ffa432f2116c39eeafec1c73..85a5156e973ef467143976b68c5e665a0c1327ca 100644
--- a/mindspore/lite/src/ops/conv2d_grad_input.cc
+++ b/mindspore/lite/src/ops/conv2d_grad_input.cc
@@ -64,7 +64,133 @@ void Conv2DGradInput::SetHasBias(bool has_bias) { this->primitive_->value.AsConv
 void Conv2DGradInput::SetActivationType(int activation_type) {
   this->primitive_->value.AsConv2DGradInput()->activationType = (schema::ActivationType)activation_type;
 }
+void Conv2DGradInput::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
+                                                const std::vector<AnfNodePtr> &inputs) {
+  auto attr = std::make_unique<schema::DepthwiseConv2DT>();
+  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
+  if (format == "NCHW") {
+    attr->format = schema::Format_NCHW;
+  } else if (format == "NHWC") {
+    attr->format = schema::Format_NHWC;
+  } else {
+    attr->format = schema::Format_NUM_OF_FORMAT;
+  }
+  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
+  attr->padUp = pad_list[0];
+  attr->padDown = pad_list[1];
+  attr->padLeft = pad_list[2];
+  attr->padRight = pad_list[3];
+
+  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
+  attr->dilateH = dilation[0];
+  attr->dilateW = dilation[1];
+
+  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
+  attr->kernelH = kernel_size[0];
+  attr->kernelW = kernel_size[1];
+
+  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
+  attr->strideH = stride[2];
+  attr->strideW = stride[3];
+
+  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
+  if (pad_mode == "valid") {
+    attr->padMode = schema::PadMode_VALID;
+  } else if (pad_mode == "same") {
+    attr->padMode = schema::PadMode_SAME;
+  } else {
+    attr->padMode = schema::PadMode_NOTSET;
+  }
+
+  if (prim.GetAttr("activation_name") != nullptr) {
+    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
+    attr->activationType = kActivationTypeMap[activate_name];
+  } else {
+    attr->activationType = schema::ActivationType_NO_ACTIVATION;
+  }
+
+  int channel_mutiplier = 1;
+  if (prim.GetAttr("channel_mutiplier") != nullptr) {
+    channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
+  }
+  attr->channelMultiplier = channel_mutiplier;
+
+  primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
+  primitive->value.value = attr.release();
+}
+
+void Conv2DGradInput::PopulaterConv2DSingleGroup(const Primitive &prim,
+                                                 schema::PrimitiveT *primitive, const int &group) {
+  auto attr = std::make_unique<schema::Conv2DT>();
+  attr->group = group;
+  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
+  if (format == "NCHW") {
+    attr->format = schema::Format_NCHW;
+  } else if (format == "NHWC") {
+    attr->format = schema::Format_NHWC;
+  } else {
+    attr->format = schema::Format_NUM_OF_FORMAT;
+  }
+  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
+  attr->padUp = pad_list[0];
+  attr->padDown = pad_list[1];
+  attr->padLeft = pad_list[2];
+  attr->padRight = pad_list[3];
+
+  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
+  attr->dilateH = dilation[0];
+  attr->dilateW = dilation[1];
+  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
+  attr->kernelH = kernel_size[0];
+  attr->kernelW = kernel_size[1];
+
+  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
+  attr->strideH = stride[2];
+  attr->strideW = stride[3];
+
+  attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
+
+  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
+  if (pad_mode == "valid") {
+    attr->padMode = schema::PadMode_VALID;
+  } else if (pad_mode == "same") {
+    attr->padMode = schema::PadMode_SAME;
+  } else {
+    attr->padMode = schema::PadMode_NOTSET;
+  }
+
+  if (prim.GetAttr("activation_name") != nullptr) {
+    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
+    attr->activationType = kActivationTypeMap[activate_name];
+  } else {
+    attr->activationType = schema::ActivationType_NO_ACTIVATION;
+  }
+
+  primitive->value.type = schema::PrimitiveType_Conv2D;
+  primitive->value.value = attr.release();
+}
+int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Conv2DGradInput;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Conv2DGradInput) {
+    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  int group = GetValue<int>(prim.GetAttr("group"));
+  if (group > 1) {
+    PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
+  } else {
+    PopulaterConv2DSingleGroup(prim, this->primitive_, group);
+  }
+  return RET_OK;
+}
 #else
 int Conv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
diff --git a/mindspore/lite/src/ops/conv2d_grad_input.h b/mindspore/lite/src/ops/conv2d_grad_input.h
index d6dab8522b0c0036bdf8b373819a1d10b38d56ba..4656addee34a59a940d362451e0bfd2e1ff76719 100644
--- a/mindspore/lite/src/ops/conv2d_grad_input.h
+++ b/mindspore/lite/src/ops/conv2d_grad_input.h
@@ -20,6 +20,8 @@
 #include <vector>
 #include <set>
 #include <cmath>
+#include <string>
+#include <memory>
 #include "ir/dtype/type_id.h"
 #include "src/ops/primitive_c.h"
@@ -48,6 +50,10 @@ class Conv2DGradInput : public PrimitiveC {
   void SetDilateH(int dilate_h);
   void SetHasBias(bool has_bias);
   void SetActivationType(int activation_type);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
+  void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
+                                 const std::vector<AnfNodePtr> &inputs);
+  void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
 #else
   Conv2DGradInput() = default;
diff --git a/mindspore/lite/src/ops/pooling_grad.cc b/mindspore/lite/src/ops/pooling_grad.cc
index 17bc053c49e0e7f61320c70debd33e4f1b5dfa6e..7eba1875a11e5eac5c9ba11f1640d8699edfee03 100644
--- a/mindspore/lite/src/ops/pooling_grad.cc
+++ b/mindspore/lite/src/ops/pooling_grad.cc
@@ -52,7 +52,64 @@ void PoolingGrad::SetPadRight(int pad_right) { this->primitive_->value.AsPooling
 void PoolingGrad::SetRoundMode(int round_mode) {
   this->primitive_->value.AsPoolingGrad()->roundMode = (schema::RoundMode)round_mode;
 }
+int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_PoolingGrad;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_PoolingGrad) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::PoolingGradT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+
+    auto format = GetValue<std::string>(prim.GetAttr("data_format"));
+    if (format == "NCHW") {
+      attr->format = schema::Format_NCHW;
+    } else if (format == "NHWC") {
+      attr->format = schema::Format_NHWC;
+    } else {
+      attr->format = schema::Format_NUM_OF_FORMAT;
+    }
+    if (prim.instance_name() == "MaxPool") {
+      attr->poolingMode = schema::PoolMode_MAX_POOLING;
+    } else if (prim.instance_name() == "MeanPool") {
+      attr->poolingMode = schema::PoolMode_MEAN_POOLING;
+    }
+    auto pad_mode = GetValue<std::string>(prim.GetAttr("padding"));
+    if (pad_mode == "VALID") {
+      attr->padMode = schema::PadMode_VALID;
+    } else if (pad_mode == "SAME") {
+      attr->padMode = schema::PadMode_SAME;
+    } else {
+      attr->padMode = schema::PadMode_NOTSET;
+    }
+
+    auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize"));
+    attr->windowH = kernel_size[2];
+    attr->windowW = kernel_size[3];
+
+    auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides"));
+    attr->strideH = stride[2];
+    attr->strideW = stride[3];
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "primitive value is nullptr";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
 #else
 int PoolingGrad::GetFormat() const { return this->primitive_->value_as_PoolingGrad()->format(); }
diff --git a/mindspore/lite/src/ops/pooling_grad.h b/mindspore/lite/src/ops/pooling_grad.h
index dbafdb9254405394ef5d30eda060e238ebba6b79..1fe47eb327aa604a00d1f869f32830e7c6c46ab5 100644
--- a/mindspore/lite/src/ops/pooling_grad.h
+++ b/mindspore/lite/src/ops/pooling_grad.h
@@ -20,6 +20,7 @@
 #include <vector>
 #include <set>
 #include <cmath>
+#include <string>
 #include "ir/dtype/type_id.h"
 #include "src/ops/primitive_c.h"
@@ -44,6 +45,7 @@ class PoolingGrad : public PrimitiveC {
   void SetPadLeft(int pad_left);
   void SetPadRight(int pad_right);
   void SetRoundMode(int round_mode);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   PoolingGrad() = default;
diff --git a/mindspore/lite/src/ops/power_grad.cc b/mindspore/lite/src/ops/power_grad.cc
index ba10623dece228e35a20bdb0159b0ae9826c41a1..5529e1055a0b3695e65eded0c9d802677eabfca8 100644
--- a/mindspore/lite/src/ops/power_grad.cc
+++ b/mindspore/lite/src/ops/power_grad.cc
@@ -26,7 +26,36 @@ float PowerGrad::GetShift() const { return this->primitive_->value.AsPowerGrad()
 void PowerGrad::SetPower(float power) { this->primitive_->value.AsPowerGrad()->power = power; }
 void PowerGrad::SetScale(float scale) { this->primitive_->value.AsPowerGrad()->scale = scale; }
 void PowerGrad::SetShift(float shift) { this->primitive_->value.AsPowerGrad()->shift = shift; }
-
+int PowerGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_PowerGrad;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_PowerGrad) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::PowerGradT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+    attr->power = GetValue<float>(prim.GetAttr("power"));
+    attr->scale = GetValue<float>(prim.GetAttr("scale"));
+    attr->shift = GetValue<float>(prim.GetAttr("shift"));
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "primitive value is nullptr";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
 #else
 float PowerGrad::GetPower() const { return this->primitive_->value_as_PowerGrad()->power(); }
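The power/scale/shift triple unpacked for PowerGrad matches the usual Caffe-style Power definition y = (shift + scale * x)^power; under that assumption (not spelled out in the patch), these three values are exactly what the backward kernel needs, since dy/dx = power * scale * (shift + scale * x)^(power - 1). A scalar sketch of that gradient (the real kernel applies it elementwise over tensors):

    #include <cmath>

    // Chain rule for y = (shift + scale * x)^power, scaled by the
    // incoming gradient dy. Assumes the Caffe-style Power definition.
    float PowerGradScalar(float dy, float x, float power, float scale, float shift) {
      return dy * power * scale * std::pow(shift + scale * x, power - 1.0f);
    }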
diff --git a/mindspore/lite/src/ops/power_grad.h b/mindspore/lite/src/ops/power_grad.h
index a3fbd79986f67e696563925da7694487f24b9dde..4c98e37f10b4e9f6e89bb8c50f01c3388808ef46 100644
--- a/mindspore/lite/src/ops/power_grad.h
+++ b/mindspore/lite/src/ops/power_grad.h
@@ -34,6 +34,7 @@ class PowerGrad : public PrimitiveC {
   void SetPower(float power);
   void SetScale(float scale);
   void SetShift(float shift);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   PowerGrad() = default;
diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc
index 961cd6cf1846a7b8f5d4e43d507a5b544b112d15..fa52d02c6881f32123cedb9f105f8849e17907b4 100644
--- a/mindspore/lite/src/ops/primitive_c.cc
+++ b/mindspore/lite/src/ops/primitive_c.cc
@@ -383,6 +383,20 @@ std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &pri
     return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType);
   } else if (op_type == "BatchNormGrad") {
     return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
+  } else if (op_type == "Conv2DGradInput") {
+    return NewPrimitiveC<Conv2DGradInput>(prim, inputs, quantType);
+  } else if (op_type == "Conv2DGradFilter") {
+    return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType);
+  } else if (op_type == "BiasGrad") {
+    return NewPrimitiveC<BiasGrad>(prim, inputs, quantType);
+  } else if (op_type == "ActivationGrad") {
+    return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType);
+  } else if (op_type == "PoolingGrad") {
+    return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType);
+  } else if (op_type == "BNGradInput") {
+    return NewPrimitiveC<BNGradInput>(prim, inputs, quantType);
+  } else if (op_type == "PowerGrad") {
+    return NewPrimitiveC<PowerGrad>(prim, inputs, quantType);
 #endif
   } else {
     MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type;
@@ -620,6 +634,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT
       return new ArithmeticGrad(primitive);
     case schema::PrimitiveType_DivGrad:
       return new ArithmeticGrad(primitive);
+    case schema::PrimitiveType_PowerGrad:
+      return new PowerGrad(primitive);
+    case schema::PrimitiveType_BNGradInput:
+      return new BNGradInput(primitive);
 #endif
     default:
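The else-if chain that UnPackFromPrimitive keeps growing is effectively a string-to-factory table. A standalone sketch of the table form, with hypothetical minimal types (the real code forwards prim, inputs, and quantType through NewPrimitiveC<T>):

    #include <functional>
    #include <memory>
    #include <string>
    #include <unordered_map>

    struct PrimitiveC { virtual ~PrimitiveC() = default; };
    struct PowerGrad : PrimitiveC {};
    struct BNGradInput : PrimitiveC {};

    std::shared_ptr<PrimitiveC> CreateByOpType(const std::string &op_type) {
      using Factory = std::function<std::shared_ptr<PrimitiveC>()>;
      static const std::unordered_map<std::string, Factory> kFactories = {
          {"PowerGrad", [] { return std::make_shared<PowerGrad>(); }},
          {"BNGradInput", [] { return std::make_shared<BNGradInput>(); }},
      };
      auto it = kFactories.find(op_type);
      return it == kFactories.end() ? nullptr : it->second();  // nullptr = unsupported op
    }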