提交 04e3d8b2 编写于 作者: S Shiyuan Shang-Guan 提交者: Li Xinqi

add scalar_mul (#1553)



Former-commit-id: bc5b1c935372311367de69e38c37db543b30a19d
上级 e0c32b4f
#include "oneflow/core/kernel/scalar_mul_kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
void ScalarMulKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* in_blob = BnInOp2Blob("in");
Blob* out_blob = BnInOp2Blob("out");
Memcpy<device_type>(ctx.device_ctx, out_blob->mut_dptr<T>(), in_blob->dptr<T>(),
out_blob->ByteSizeOfDataContentField());
KernelUtil<device_type, T>::Scal(ctx.device_ctx, out_blob->shape().elem_cnt(),
static_cast<T>(this->op_conf().scalar_mul_conf().scalar()),
out_blob->mut_dptr<T>(), 1);
}
template<DeviceType device_type, typename T>
void ScalarMulKernel<device_type, T>::BackwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out"));
Blob* in_diff_blob = BnInOp2Blob(GenDiffBn("in"));
Memcpy<device_type>(ctx.device_ctx, in_diff_blob->mut_dptr<T>(), out_diff_blob->dptr<T>(),
out_diff_blob->ByteSizeOfDataContentField());
KernelUtil<device_type, T>::Scal(ctx.device_ctx, in_diff_blob->shape().elem_cnt(),
static_cast<T>(this->op_conf().scalar_mul_conf().scalar()),
in_diff_blob->mut_dptr<T>(), 1);
}
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kScalarMulConf, ScalarMulKernel, FLOATING_DATA_TYPE_SEQ);
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_SCALAR_MUL_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_SCALAR_MUL_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class ScalarMulKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(ScalarMulKernel);
ScalarMulKernel() = default;
~ScalarMulKernel() = default;
private:
void ForwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
void BackwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
const PbMessage& GetCustomizedOpConf() const override {
return this->op_conf().scalar_mul_conf();
}
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_SCALAR_MUL_KERNEL_H_
......@@ -674,6 +674,12 @@ message HingeLossOpConf {
optional Norm norm = 7[default = L1];
}
message ScalarMulOpConf {
required string in = 1;
required string out = 2;
required float scalar = 3;
}
message OperatorConf {
required string name = 1;
optional string model_load_dir = 2;
......@@ -742,6 +748,7 @@ message OperatorConf {
LossPrintOpConf loss_print_conf = 235;
DefineTestBlobConf define_test_blob_conf = 236;
PReluOpConf prelu_conf = 237;
ScalarMulOpConf scalar_mul_conf = 238;
}
}
......
#include "oneflow/core/operator/scalar_mul_op.h"
namespace oneflow {
void ScalarMulOp::InitFromOpConf() {
EnrollInputBn("in");
EnrollOutputBn("out");
}
void ScalarMulOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
*GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
}
REGISTER_OP(OperatorConf::kScalarMulConf, ScalarMulOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_SCALAR_MUL_OP_H_
#define ONEFLOW_CORE_OPERATOR_SCALAR_MUL_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class ScalarMulOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(ScalarMulOp);
ScalarMulOp() = default;
~ScalarMulOp() = default;
void InitFromOpConf() override;
void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
bool IsElemWiseOp() const override { return true; }
const PbMessage& GetCustomizedConf() const override { return op_conf().scalar_mul_conf(); }
bool NeedInBlobWhenBackward() const override { return false; }
bool NeedOutBlobWhenBackward() const override { return false; }
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_SCALAR_MUL_OP_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册