提交 5f0114a9 编写于 作者: W willzhang4a58

Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow


Former-commit-id: fdc92422
......@@ -5,6 +5,10 @@
#define final
#endif
#ifndef private
#define private public
#endif
#include <gmock/gmock.h>
#include <gtest/gtest.h>
......
......@@ -109,13 +109,13 @@ void ConvolutionKernel<device_type, T>::ForwardDataContent(
in_shape.At(3), conv_conf.kernel_h(), conv_conf.kernel_w(),
conv_conf.pad_h(), conv_conf.pad_w(), conv_conf.stride_h(),
conv_conf.stride_w(), conv_conf.dilation_h(), conv_conf.dilation_w(),
col_buf_blob->mut_dptr<T>() + i * col_im_sz);
col_buf_blob->mut_dptr<T>() + col_im_sz);
KernelUtil<device_type, T>::Gemm(
ctx.device_ctx, CBLAS_ORDER::CblasRowMajor, CblasNoTrans, CblasTrans,
out_blob->shape().At(1), out_blob->shape().Count(2),
weight_blob->shape().At(1), static_cast<T>(1.0), weight_blob->dptr<T>(),
weight_blob->shape().At(1), col_buf_blob->dptr<T>() + i * col_im_sz,
weight_blob->shape().At(1), col_buf_blob->dptr<T>() + col_im_sz,
weight_blob->shape().At(1), static_cast<T>(0.0),
out_blob->mut_dptr<T>() + i * out_im_sz, col_buf_blob->shape().At(1));
......@@ -139,16 +139,28 @@ template<DeviceType device_type, typename T>
void ConvolutionKernel<device_type, T>::ComputeWeightDiff(
const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* in_blob = BnInOp2Blob("in");
const Shape& in_shape = in_blob->shape();
const int64_t in_im_sz = in_shape.Count(1);
Blob* weight_diff_blob = BnInOp2Blob("weight_diff");
const Blob* col_buf_blob = BnInOp2Blob("col_buf");
Blob* col_buf_blob = BnInOp2Blob("col_buf");
const Blob* out_diff_blob = BnInOp2Blob("out_diff");
const int64_t out_im_sz = out_diff_blob->shape().Count(1);
const int64_t col_im_sz = col_buf_blob->shape().Count(1);
const int64_t data_num = out_diff_blob->shape().At(0);
const int64_t conv_sliding_window_steps = out_diff_blob->shape().Count(2);
Memset<device_type>(ctx.device_ctx, weight_diff_blob->mut_dptr(), 0,
weight_diff_blob->ByteSizeOfDataContentField());
const ConvolutionOpConf& conv_conf = this->op_conf().convolution_conf();
for (size_t i = 0; i < data_num; ++i) {
ConvolutionKernelUtil<device_type, T>::Im2Col(
ctx, in_blob->dptr<T>() + i * in_im_sz, in_shape.At(1), in_shape.At(2),
in_shape.At(3), conv_conf.kernel_h(), conv_conf.kernel_w(),
conv_conf.pad_h(), conv_conf.pad_w(), conv_conf.stride_h(),
conv_conf.stride_w(), conv_conf.dilation_h(), conv_conf.dilation_w(),
col_buf_blob->mut_dptr<T>() + col_im_sz);
KernelUtil<device_type, T>::Gemm(
ctx.device_ctx, CBLAS_ORDER::CblasRowMajor, CblasNoTrans, CblasNoTrans,
weight_diff_blob->shape().At(0), weight_diff_blob->shape().At(1),
......@@ -156,7 +168,7 @@ void ConvolutionKernel<device_type, T>::ComputeWeightDiff(
static_cast<T>(1.0) / conv_sliding_window_steps,
out_diff_blob->dptr<T>() + i * out_im_sz,
out_diff_blob->shape().Count(2),
col_buf_blob->dptr<T>() + i * col_buf_blob->shape().Count(1),
col_buf_blob->dptr<T>() + col_buf_blob->shape().Count(1),
col_buf_blob->shape().At(2), static_cast<T>(1.0),
weight_diff_blob->mut_dptr<T>(), weight_diff_blob->shape().At(1));
}
......@@ -199,6 +211,9 @@ void ConvolutionKernel<device_type, T>::ComputeInputDiff(
const int64_t out_im_sz = out_diff_blob->shape().Count(1);
const int64_t data_num = out_diff_blob->shape().At(0);
const Shape& in_diff_shape = in_diff_blob->shape();
const ConvolutionOpConf& conv_conf = this->op_conf().convolution_conf();
for (size_t i = 0; i < data_num; ++i) {
KernelUtil<device_type, T>::Gemm(
ctx.device_ctx, CBLAS_ORDER::CblasRowMajor, CblasTrans, CblasNoTrans,
......@@ -207,15 +222,11 @@ void ConvolutionKernel<device_type, T>::ComputeInputDiff(
out_diff_blob->dptr<T>() + i * out_im_sz,
out_diff_blob->shape().Count(2), weight_blob->dptr<T>(),
weight_blob->shape().At(1), static_cast<T>(0.0),
col_buf_blob->mut_dptr<T>() + i * col_buf_blob->shape().Count(1),
col_buf_blob->mut_dptr<T>() + col_buf_blob->shape().Count(1),
col_buf_blob->shape().At(2));
}
const Shape& in_diff_shape = in_diff_blob->shape();
const ConvolutionOpConf& conv_conf = this->op_conf().convolution_conf();
for (size_t i = 0; i < data_num; ++i) {
ConvolutionKernelUtil<device_type, T>::Col2Im(
ctx, col_buf_blob->dptr<T>() + i * col_buf_blob->shape().Count(1),
ctx, col_buf_blob->dptr<T>() + col_buf_blob->shape().Count(1),
in_diff_shape.At(1), in_diff_shape.At(2), in_diff_shape.At(3),
conv_conf.kernel_h(), conv_conf.kernel_w(), conv_conf.pad_h(),
conv_conf.pad_w(), conv_conf.stride_h(), conv_conf.stride_w(),
......
#include "oneflow/core/kernel/kernel_test_common.h"
#include "oneflow/core/kernel/opkernel_test_common.h"
#include <memory>
#include <random>
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/device/cpu_device_context.h"
......@@ -7,6 +7,29 @@ namespace oneflow {
namespace test {
std::function<BlobDesc*(const std::string)> ConstructBn2BlobDescFunc(
    std::shared_ptr<Operator> op) {
  // Builds a lookup function that maps each of `op`'s blob names (data_tmp,
  // input(_diff), output(_diff), model(_diff), model_tmp) to a freshly
  // default-constructed BlobDesc.
  // Ownership fix: the original leaked both the `new HashMap` and every
  // `new BlobDesc` (raw pointers captured in the closure, never deleted).
  // Here the map is shared_ptr-owned by the returned closure and holds
  // unique_ptr values, so everything is released when the last copy of the
  // closure is destroyed. The external signature is unchanged.
  auto bn2blobdesc_map =
      std::make_shared<HashMap<std::string, std::unique_ptr<BlobDesc>>>();
  auto InsertBnsWithEmptyBlobDesc2Map =
      [&bn2blobdesc_map](const std::vector<std::string>& bns) {
        for (const std::string& bn : bns) {
          // CHECK guards against a blob name appearing in more than one list.
          CHECK(bn2blobdesc_map
                    ->emplace(bn, std::unique_ptr<BlobDesc>(new BlobDesc))
                    .second);
        }
      };
  InsertBnsWithEmptyBlobDesc2Map(op->data_tmp_bns());
  InsertBnsWithEmptyBlobDesc2Map(op->input_bns());
  InsertBnsWithEmptyBlobDesc2Map(op->input_diff_bns());
  InsertBnsWithEmptyBlobDesc2Map(op->output_bns());
  InsertBnsWithEmptyBlobDesc2Map(op->output_diff_bns());
  InsertBnsWithEmptyBlobDesc2Map(op->model_bns());
  InsertBnsWithEmptyBlobDesc2Map(op->model_diff_bns());
  InsertBnsWithEmptyBlobDesc2Map(op->model_tmp_bns());
  // Lookup of an unknown bn fails loudly via at() (presumably HashMap is an
  // unordered_map alias, so at() throws std::out_of_range -- TODO confirm).
  // The returned raw pointer is a non-owning observer; the captured map
  // keeps the BlobDesc alive as long as the closure exists.
  return [bn2blobdesc_map](const std::string& bn) {
    return bn2blobdesc_map->at(bn).get();
  };
}
template<>
Blob* CreateBlob<DeviceType::kCPU>(const BlobDesc* blob_desc) {
void* mem_ptr = nullptr;
......
#include <random>
#include "oneflow/core/device/cuda_device_context.h"
#include "oneflow/core/kernel/kernel_test_common.h"
#include "oneflow/core/kernel/opkernel_test_common.h"
namespace oneflow {
......
#ifndef ONEFLOW_CORE_KERNEL_KERNEL_TEST_COMMON_H_
#define ONEFLOW_CORE_KERNEL_KERNEL_TEST_COMMON_H_
#ifndef ONEFLOW_CORE_KERNEL_OPKERNEL_TEST_COMMON_H_
#define ONEFLOW_CORE_KERNEL_OPKERNEL_TEST_COMMON_H_
#include "oneflow/core/common/test_util.h"
#include "oneflow/core/job/resource.pb.h"
#include "oneflow/core/kernel/kernel_context.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/register/blob.h"
namespace oneflow {
namespace test {
std::function<BlobDesc*(const std::string)> ConstructBn2BlobDescFunc(
std::shared_ptr<Operator>);
template<DeviceType device_type>
Blob* CreateBlob(const BlobDesc*);
......@@ -63,4 +68,4 @@ class KTCommon final {
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_KERNEL_TEST_COMMON_H_
#endif // ONEFLOW_CORE_KERNEL_OPKERNEL_TEST_COMMON_H_
......@@ -57,7 +57,7 @@ void ConvolutionOp::InferBlobDescs(
// col_buf
BlobDesc* col_buf_blob_desc = GetBlobDesc4BnInOp("col_buf");
col_buf_blob_desc->mut_shape() = Shape({data_num, output_size, c_i * kernel});
col_buf_blob_desc->mut_shape() = Shape({1, output_size, c_i * kernel});
col_buf_blob_desc->set_data_type(JobDesc::Singleton()->DefaultDataType());
col_buf_blob_desc->set_has_data_id(false);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册