提交 137537b1 编写于 作者: Q qiaolongfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into mnist

...@@ -19,8 +19,13 @@ limitations under the License. */ ...@@ -19,8 +19,13 @@ limitations under the License. */
#include <memory> // for unique_ptr #include <memory> // for unique_ptr
#include <mutex> // for call_once #include <mutex> // for call_once
#include "glog/logging.h"
#include "paddle/memory/detail/buddy_allocator.h" #include "paddle/memory/detail/buddy_allocator.h"
#include "paddle/memory/detail/system_allocator.h" #include "paddle/memory/detail/system_allocator.h"
#include "paddle/platform/gpu_info.h"
DECLARE_double(fraction_of_gpu_memory_to_use);
namespace paddle { namespace paddle {
namespace memory { namespace memory {
...@@ -80,6 +85,11 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { ...@@ -80,6 +85,11 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
platform::GpuMinChunkSize(), platform::GpuMinChunkSize(),
platform::GpuMaxChunkSize())); platform::GpuMaxChunkSize()));
} }
VLOG(3) << "\n\nNOTE: each GPU device use "
<< FLAGS_fraction_of_gpu_memory_to_use * 100 << "% of GPU memory.\n"
<< "You can set environment variable '"
<< platform::kEnvFractionGpuMemoryToUse
<< "' to change the fraction of GPU usage.\n\n";
}); });
platform::SetDeviceId(gpu_id); platform::SetDeviceId(gpu_id);
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/platform/gpu_info.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
namespace paddle { namespace paddle {
......
...@@ -25,8 +25,8 @@ void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA, ...@@ -25,8 +25,8 @@ void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
const float alpha, const float* A, const float alpha, const float* A,
const float* B, const float beta, float* C, const float* B, const float beta, float* C,
platform::DeviceContext* context) { platform::DeviceContext* context) {
int lda = K; int lda = (transA == CblasNoTrans) ? K : M;
int ldb = N; int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N; int ldc = N;
cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc); beta, C, ldc);
...@@ -40,8 +40,8 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA, ...@@ -40,8 +40,8 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
const double* B, const double beta, const double* B, const double beta,
double* C, double* C,
platform::DeviceContext* context) { platform::DeviceContext* context) {
int lda = K; int lda = (transA == CblasNoTrans) ? K : M;
int ldb = N; int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N; int ldc = N;
cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc); beta, C, ldc);
......
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor;
class MulOp : public framework::OperatorWithKernel { class MulOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -59,10 +61,23 @@ class MulOpGrad : public framework::OperatorWithKernel { ...@@ -59,10 +61,23 @@ class MulOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {
std::string DebugString() const override { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
LOG(INFO) << "MulGrad"; PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
return ""; PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE(x_dims[0] == out_dims[0],
"Out@GRAD M X N must equal to X dims 0, M ");
PADDLE_ENFORCE(y_dims[1] == out_dims[1],
"Out@GRAD M X N must equal to Y dims 1, N ");
x_grad->Resize(x_dims);
y_grad->Resize(y_dims);
} }
}; };
...@@ -72,3 +87,5 @@ class MulOpGrad : public framework::OperatorWithKernel { ...@@ -72,3 +87,5 @@ class MulOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::CPUPlace, float>);
...@@ -17,3 +17,5 @@ ...@@ -17,3 +17,5 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::GPUPlace, float>);
...@@ -31,18 +31,34 @@ template <typename Place, typename T> ...@@ -31,18 +31,34 @@ template <typename Place, typename T>
class MulKernel : public framework::OpKernel { class MulKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = { auto* X = context.Input<Tensor>("X");
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}}; auto* Y = context.Input<Tensor>("Y");
auto* input0 = context.Input<Tensor>("X"); auto* Z = context.Output<Tensor>("Out");
auto* input1 = context.Input<Tensor>("Y"); Z->mutable_data<T>(context.GetPlace());
auto* output = context.Output<Tensor>("Out"); auto* device_context =
output->mutable_data<T>(context.GetPlace()); const_cast<platform::DeviceContext*>(context.device_context_);
auto X = EigenMatrix<T>::From(*input0); math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context);
auto Y = EigenMatrix<T>::From(*input1); }
auto Z = EigenMatrix<T>::From(*output); };
auto& place = context.GetEigenDevice<Place>();
template <typename Place, typename T>
Z.device(place) = X.contract(Y, dim_pair); class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* X = ctx.Input<Tensor>("X");
auto* Y = ctx.Input<Tensor>("Y");
auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
dX->mutable_data<T>(ctx.GetPlace());
dY->mutable_data<T>(ctx.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_);
// dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N
math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
// dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);
} }
}; };
......
...@@ -17,7 +17,9 @@ ...@@ -17,7 +17,9 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class RowWiseAddOp : public framework::OperatorWithKernel { using framework::Tensor;
class RowwiseAddOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -34,9 +36,9 @@ class RowWiseAddOp : public framework::OperatorWithKernel { ...@@ -34,9 +36,9 @@ class RowWiseAddOp : public framework::OperatorWithKernel {
} }
}; };
class RowWiseAddOpMaker : public framework::OpProtoAndCheckerMaker { class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
RowWiseAddOpMaker(framework::OpProto *proto, RowwiseAddOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left input of row-wise add op, must be matrix"); AddInput("X", "The left input of row-wise add op, must be matrix");
...@@ -49,12 +51,32 @@ for i in xrange(X.shape[0]): ...@@ -49,12 +51,32 @@ for i in xrange(X.shape[0]):
)DOC"); )DOC");
} }
}; };
class RowwiseAddGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto dims0 = ctx.Input<Tensor>("X")->dims();
auto dims1 = ctx.Input<Tensor>("b")->dims();
PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(rowwise_add, ops::RowWiseAddOp, REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker,
ops::RowWiseAddOpMaker); rowwise_add_grad, ops::RowwiseAddGradOp);
REGISTER_OP_CPU_KERNEL(
rowwise_add, ops::RowwiseAddKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
rowwise_add, ops::RowWiseAddKernel<paddle::platform::CPUPlace, float>); rowwise_add_grad,
ops::RowwiseAddGradKernel<paddle::platform::CPUPlace, float>);
...@@ -17,4 +17,4 @@ ...@@ -17,4 +17,4 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
rowwise_add, ops::RowWiseAddKernel<paddle::platform::GPUPlace, float>); rowwise_add, ops::RowwiseAddKernel<paddle::platform::GPUPlace, float>);
...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class RowWiseAddKernel : public framework::OpKernel { class RowwiseAddKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output<Tensor>("Out"); auto out = context.Output<Tensor>("Out");
...@@ -47,5 +47,25 @@ class RowWiseAddKernel : public framework::OpKernel { ...@@ -47,5 +47,25 @@ class RowWiseAddKernel : public framework::OpKernel {
} }
}; };
template <typename Place, typename T>
class RowwiseAddGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
auto* db = context.Output<Tensor>(framework::GradVarName("b"));
dX->mutable_data<T>(context.GetPlace());
db->mutable_data<T>(context.GetPlace());
auto OutGrad = EigenMatrix<T>::From(*dOut);
auto place = context.GetEigenDevice<Place>();
EigenMatrix<T>::From(*dX).device(place) = OutGrad;
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
// colwise add
Eigen::array<int, 1> dims{{1}}; /* dimension to reduce */
EigenVector<T>::Flatten(*db).device(place) = OutGrad.sum(dims);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -30,7 +30,7 @@ class SGDOpKernel : public framework::OpKernel { ...@@ -30,7 +30,7 @@ class SGDOpKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("param"); auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad"); auto grad = ctx.Input<Tensor>("grad");
auto param_out = ctx.Output<Tensor>(0); auto param_out = ctx.Output<Tensor>("param_out");
float lr = ctx.op_.GetAttr<float>("learning_rate"); float lr = ctx.op_.GetAttr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
......
cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog) cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog)
cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info) cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info)
nv_library(gpu_info SRCS gpu_info.cc DEPS gflags) nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog)
cc_library(place SRCS place.cc) cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags) cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
...@@ -9,6 +9,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) ...@@ -9,6 +9,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
add_subdirectory(dynload) add_subdirectory(dynload)
cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece) cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece)
cc_test(environment_test SRCS environment_test.cc DEPS stringpiece)
IF(WITH_GPU) IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader) set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdlib.h>
#include <unistd.h>
#include <vector>
#include "paddle/platform/enforce.h"
#include "paddle/string/piece.h"
extern char** environ; // for environment variables
namespace paddle {
namespace platform {
inline void SetEnvVariable(const std::string& name, const std::string& value) {
PADDLE_ENFORCE_NE(setenv(name.c_str(), value.c_str(), 1), -1,
"Failed to set environment variable %s=%s", name, value);
}
inline void UnsetEnvVariable(const std::string& name) {
PADDLE_ENFORCE_NE(unsetenv(name.c_str()), -1,
"Failed to unset environment variable %s", name);
}
inline bool IsEnvVarDefined(const std::string& name) {
return std::getenv(name.c_str()) != nullptr;
}
inline std::string GetEnvValue(const std::string& name) {
PADDLE_ENFORCE(IsEnvVarDefined(name),
"Tried to access undefined environment variable %s", name);
return std::getenv(name.c_str());
}
inline std::vector<std::string> GetAllEnvVariables() {
std::vector<std::string> vars;
for (auto var = environ; *var != nullptr; ++var) {
auto tail = string::Index(*var, "=");
auto name = string::SubStr(*var, 0, tail).ToString();
vars.push_back(name);
}
return vars;
}
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/environment.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
TEST(ENVIRONMENT, ACCESS) {
namespace platform = paddle::platform;
namespace string = paddle::string;
platform::SetEnvVariable("PADDLE_USE_ENV", "TRUE");
EXPECT_TRUE(platform::IsEnvVarDefined("PADDLE_USE_ENV"));
EXPECT_EQ(platform::GetEnvValue("PADDLE_USE_ENV"), "TRUE");
platform::UnsetEnvVariable("PADDLE_USE_ENV");
EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV"));
platform::SetEnvVariable("PADDLE_USE_ENV1", "Hello ");
platform::SetEnvVariable("PADDLE_USE_ENV2", "World, ");
platform::SetEnvVariable("PADDLE_USE_ENV3", "PaddlePaddle!");
std::string env_info;
auto vars = platform::GetAllEnvVariables();
for_each(vars.begin(), vars.end(), [&](const std::string& var) {
env_info += platform::GetEnvValue(var);
});
EXPECT_TRUE(string::Contains(env_info, "Hello World, PaddlePaddle!"));
platform::UnsetEnvVariable("PADDLE_USE_ENV1");
platform::UnsetEnvVariable("PADDLE_USE_ENV2");
platform::UnsetEnvVariable("PADDLE_USE_ENV3");
env_info.clear();
vars = platform::GetAllEnvVariables();
for_each(vars.begin(), vars.end(), [&](const std::string& var) {
env_info += platform::GetEnvValue(var);
});
EXPECT_FALSE(string::Contains(env_info, "Hello World, PaddlePaddle!"));
EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV1"));
EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV2"));
EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV3"));
}
...@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/platform/gpu_info.h" #include "paddle/platform/gpu_info.h"
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/environment.h"
DEFINE_double(fraction_of_gpu_memory_to_use, 0.95, DEFINE_double(fraction_of_gpu_memory_to_use, 0.95,
"Default use 95% of GPU memory for PaddlePaddle," "Default use 95% of GPU memory for PaddlePaddle,"
...@@ -70,6 +73,13 @@ size_t GpuMaxChunkSize() { ...@@ -70,6 +73,13 @@ size_t GpuMaxChunkSize() {
GpuMemoryUsage(available, total); GpuMemoryUsage(available, total);
if (IsEnvVarDefined(kEnvFractionGpuMemoryToUse)) {
auto val = std::stod(GetEnvValue(kEnvFractionGpuMemoryToUse));
PADDLE_ENFORCE_GT(val, 0.0);
PADDLE_ENFORCE_LE(val, 1.0);
FLAGS_fraction_of_gpu_memory_to_use = val;
}
// Reserving the rest memory for page tables, etc. // Reserving the rest memory for page tables, etc.
size_t reserving = (1 - FLAGS_fraction_of_gpu_memory_to_use) * total; size_t reserving = (1 - FLAGS_fraction_of_gpu_memory_to_use) * total;
......
...@@ -18,10 +18,15 @@ limitations under the License. */ ...@@ -18,10 +18,15 @@ limitations under the License. */
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <stddef.h> #include <stddef.h>
#include <string>
namespace paddle { namespace paddle {
namespace platform { namespace platform {
//! Environment variable: fraction of GPU memory to use on each device.
const std::string kEnvFractionGpuMemoryToUse =
"PADDLE_FRACTION_GPU_MEMORY_TO_USE";
//! Get the total number of GPU devices in system. //! Get the total number of GPU devices in system.
int GetDeviceCount(); int GetDeviceCount();
......
...@@ -25,4 +25,5 @@ py_test(test_operator SRCS test_operator.py) ...@@ -25,4 +25,5 @@ py_test(test_operator SRCS test_operator.py)
# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) # py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test(test_recurrent_op SRCS test_recurrent_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py)
py_test(test_sgd_op SRCS test_sgd_op.py)
py_test(test_gradient_checker SRCS test_gradient_checker.py) py_test(test_gradient_checker SRCS test_gradient_checker.py)
import unittest import unittest
from op_test_util import OpTestMeta
import numpy as np import numpy as np
from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta
class TestMulOp(unittest.TestCase): class TestMulOp(unittest.TestCase):
...@@ -15,5 +16,19 @@ class TestMulOp(unittest.TestCase): ...@@ -15,5 +16,19 @@ class TestMulOp(unittest.TestCase):
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
class MulGradOpTest(GradientChecker):
def test_mul(self):
op = create_op("mul")
inputs = {
'X': np.random.random((32, 84)).astype("float32"),
'Y': np.random.random((84, 100)).astype("float32")
}
# mul op will enlarge the relative error
self.check_grad(
op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5)
# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
import unittest import unittest
from op_test_util import OpTestMeta
import numpy as np import numpy as np
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
class TestRowwiseAddOp(unittest.TestCase): class TestRowwiseAddOp(unittest.TestCase):
...@@ -15,5 +16,15 @@ class TestRowwiseAddOp(unittest.TestCase): ...@@ -15,5 +16,15 @@ class TestRowwiseAddOp(unittest.TestCase):
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])} self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
class RowwiseAddGradOpTest(GradientChecker):
def test_rowwise_add(self):
op = create_op("rowwise_add")
inputs = {
"X": np.random.uniform(0.1, 1, [10, 10]).astype("float32"),
"b": np.random.uniform(0.1, 1, [10]).astype("float32")
}
self.check_grad(op, inputs, set(["X", "b"]), "Out")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册