提交 fc1d615f 编写于 作者: K kai00

grad ops add

上级 bbafa9db
......@@ -179,6 +179,7 @@ union PrimitiveType {
Conv2DGradInput,
PoolingGrad,
BNGrad,
BNGradInput,
ApplyMomentum,
BiasGrad,
SoftmaxCrossEntropy,
......
......@@ -398,7 +398,10 @@ table BNGrad {
eps : float;
momentum: float;
}
table BNGradInput {
eps : float;
momentum: float;
}
table Scale {
axis: int;
}
......
......@@ -25,6 +25,36 @@ void ActivationGrad::SetType(int type) {
this->primitive_->value.AsActivationGrad()->type = (schema::ActivationType)type;
}
void ActivationGrad::SetAlpha(float alpha) { this->primitive_->value.AsActivationGrad()->alpha = alpha; }
int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_ActivationGrad;
}
if (this->primitive_->value.type != schema::PrimitiveType_ActivationGrad) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
auto attr = std::make_unique<schema::ActivationGradT>();
if (prim.name() == "ReLU") {
attr->type = schema::ActivationType_RELU;
} else if (prim.name() == "Sigmoid") {
attr->type = schema::ActivationType_SIGMOID;
} else if (prim.name() == "ReLU6") {
attr->type = schema::ActivationType_RELU6;
}
auto alpha = GetValue<float>(prim.GetAttr("alpha"));
attr->alpha = alpha;
this->primitive_->value.value = attr.release();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
return RET_OK;
}
#else
int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
......
......@@ -20,6 +20,7 @@
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
......@@ -33,6 +34,7 @@ class ActivationGrad : public PrimitiveC {
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetType(int type);
void SetAlpha(float alpha);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
ActivationGrad() = default;
......
......@@ -22,7 +22,34 @@ namespace lite {
std::vector<int> BiasGrad::GetAxis() const { return this->primitive_->value.AsBiasGrad()->axis; }
void BiasGrad::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasGrad()->axis = axis; }
int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_BiasGrad;
}
if (this->primitive_->value.type != schema::PrimitiveType_BiasGrad) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::BiasGradT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int BiasGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
......
......@@ -33,7 +33,7 @@ class BiasGrad : public PrimitiveC {
BiasGrad() = default;
explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(const std::vector<int> &axis);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
BiasGrad() = default;
......
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_
#include <vector>
#include <set>
......
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/bn_grad_input.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float BNGradInput::GetEps() const { return this->primitive_->value.AsBNGradInput()->eps; }
float BNGradInput::GetMomentum() const { return this->primitive_->value.AsBNGradInput()->momentum; }
void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->eps = eps; }
void BNGradInput::SetMomentum(float momentum) { this->primitive_->value.AsBNGradInput()->momentum = momentum; }
int BNGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_BNGradInput;
}
if (this->primitive_->value.type != schema::PrimitiveType_BNGradInput) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::BNGradInputT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->eps = GetValue<float>(prim.GetAttr("eps"));
attr->momentum = GetValue<float>(prim.GetAttr("momentum"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int BNGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_BNGradInput();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_BNGradInputInput return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateBNGradInput(*fbb, attr->eps(), attr->momentum());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGradInput, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); }
float BNGradInput::GetMomentum() const { return this->primitive_->value_as_BNGradInput()->momentum(); }
#endif
} // namespace lite
} // namespace mindspore
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class BNGradInput : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BNGradInput, PrimitiveC);
BNGradInput() = default;
explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetEps(float eps);
void SetMomentum(float momentum);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
BNGradInput() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
float GetEps() const;
float GetMomentum() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
......@@ -66,7 +66,133 @@ void Conv2DGradFilter::SetHasBias(bool has_bias) { this->primitive_->value.AsCon
void Conv2DGradFilter::SetActivationType(int activation_type) {
this->primitive_->value.AsConv2DGradFilter()->activationType = (schema::ActivationType)activation_type;
}
void Conv2DGradFilter::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
const std::vector<AnfNodePtr> &inputs) {
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
int channel_mutiplier = 1;
if (prim.GetAttr("channel_mutiplier") != nullptr) {
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
}
attr->channelMultiplier = channel_mutiplier;
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = attr.release();
}
void Conv2DGradFilter::PopulaterConv2DSingleGroup(const Primitive &prim,
schema::PrimitiveT *primitive, const int &group) {
auto attr = std::make_unique<schema::Conv2DT>();
attr->group = group;
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
primitive->value.type = schema::PrimitiveType_Conv2D;
primitive->value.value = attr.release();
}
int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Conv2DGradFilter;
}
if (this->primitive_->value.type != schema::PrimitiveType_Conv2DGradFilter) {
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
int group = GetValue<int>(prim.GetAttr("group"));
if (group > 1) {
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
} else {
PopulaterConv2DSingleGroup(prim, this->primitive_, group);
}
return RET_OK;
}
#else
int Conv2DGradFilter::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
......
......@@ -20,6 +20,8 @@
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include <string>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
......@@ -48,6 +50,10 @@ class Conv2DGradFilter : public PrimitiveC {
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
const std::vector<AnfNodePtr> &inputs);
void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
#else
Conv2DGradFilter() = default;
......
......@@ -64,7 +64,133 @@ void Conv2DGradInput::SetHasBias(bool has_bias) { this->primitive_->value.AsConv
void Conv2DGradInput::SetActivationType(int activation_type) {
this->primitive_->value.AsConv2DGradInput()->activationType = (schema::ActivationType)activation_type;
}
void Conv2DGradInput::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
const std::vector<AnfNodePtr> &inputs) {
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
int channel_mutiplier = 1;
if (prim.GetAttr("channel_mutiplier") != nullptr) {
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
}
attr->channelMultiplier = channel_mutiplier;
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = attr.release();
}
void Conv2DGradInput::PopulaterConv2DSingleGroup(const Primitive &prim,
schema::PrimitiveT *primitive, const int &group) {
auto attr = std::make_unique<schema::Conv2DT>();
attr->group = group;
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
primitive->value.type = schema::PrimitiveType_Conv2D;
primitive->value.value = attr.release();
}
int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Conv2DGradInput;
}
if (this->primitive_->value.type != schema::PrimitiveType_Conv2DGradInput) {
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
int group = GetValue<int>(prim.GetAttr("group"));
if (group > 1) {
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
} else {
PopulaterConv2DSingleGroup(prim, this->primitive_, group);
}
return RET_OK;
}
#else
int Conv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
......
......@@ -20,6 +20,8 @@
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include <string>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
......@@ -48,6 +50,10 @@ class Conv2DGradInput : public PrimitiveC {
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
const std::vector<AnfNodePtr> &inputs);
void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
#else
Conv2DGradInput() = default;
......
......@@ -52,7 +52,64 @@ void PoolingGrad::SetPadRight(int pad_right) { this->primitive_->value.AsPooling
void PoolingGrad::SetRoundMode(int round_mode) {
this->primitive_->value.AsPoolingGrad()->roundMode = (schema::RoundMode)round_mode;
}
int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_PoolingGrad;
}
if (this->primitive_->value.type != schema::PrimitiveType_PoolingGrad) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::PoolingGradT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
if (prim.instance_name() == "MaxPool") {
attr->poolingMode = schema::PoolMode_MAX_POOLING;
} else if (prim.instance_name() == "MeanPool") {
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
}
auto pad_mode = GetValue<std::string>(prim.GetAttr("padding"));
if (pad_mode == "VALID") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "SAME") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize"));
attr->windowH = kernel_size[2];
attr->windowW = kernel_size[3];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides"));
attr->strideH = stride[2];
attr->strideW = stride[3];
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int PoolingGrad::GetFormat() const { return this->primitive_->value_as_PoolingGrad()->format(); }
......
......@@ -20,6 +20,7 @@
#include <vector>
#include <set>
#include <cmath>
#include <string>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
......@@ -44,6 +45,7 @@ class PoolingGrad : public PrimitiveC {
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetRoundMode(int round_mode);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
PoolingGrad() = default;
......
......@@ -26,7 +26,36 @@ float PowerGrad::GetShift() const { return this->primitive_->value.AsPowerGrad()
void PowerGrad::SetPower(float power) { this->primitive_->value.AsPowerGrad()->power = power; }
void PowerGrad::SetScale(float scale) { this->primitive_->value.AsPowerGrad()->scale = scale; }
void PowerGrad::SetShift(float shift) { this->primitive_->value.AsPowerGrad()->shift = shift; }
int PowerGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_PowerGrad;
}
if (this->primitive_->value.type != schema::PrimitiveType_PowerGrad) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::PowerGradT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->power = GetValue<float>(prim.GetAttr("power"));
attr->scale = GetValue<float>(prim.GetAttr("scale"));
attr->shift = GetValue<float>(prim.GetAttr("shift"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
float PowerGrad::GetPower() const { return this->primitive_->value_as_PowerGrad()->power(); }
......
......@@ -34,6 +34,7 @@ class PowerGrad : public PrimitiveC {
void SetPower(float power);
void SetScale(float scale);
void SetShift(float shift);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
PowerGrad() = default;
......
......@@ -383,6 +383,20 @@ std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &pri
return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType);
} else if (op_type == "BatchNormGrad") {
return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
} else if (op_type == "Conv2DGradInput") {
return NewPrimitiveC<Conv2DGradInput>(prim, inputs, quantType);
} else if (op_type == "Conv2DGradFilter") {
return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType);
} else if (op_type == "BiasGrad") {
return NewPrimitiveC<BiasGrad>(prim, inputs, quantType);
} else if (op_type == "ActivationGrad") {
return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType);
} else if (op_type == "PoolingGrad") {
return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType);
} else if (op_type == "BNGradInput") {
return NewPrimitiveC<BNGradInput>(prim, inputs, quantType);
} else if (op_type == "PowerGrad") {
return NewPrimitiveC<PowerGrad>(prim, inputs, quantType);
#endif
} else {
MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type;
......@@ -620,6 +634,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT
return new ArithmeticGrad(primitive);
case schema::PrimitiveType_DivGrad:
return new ArithmeticGrad(primitive);
case schema::PrimitiveType_PowerGrad:
return new PowerGrad(primitive);
case schema::PrimitiveType_BNGradInput:
return new BNGradInput(primitive);
#endif
default:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册