diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index 0266bf4f7d65c7aafd4242af41cbd1c71f44bff8..29bc26f9d3bca0e30896657431f9a9bb1dac0d1d 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -19,8 +19,13 @@ limitations under the License. */ #include // for unique_ptr #include // for call_once +#include "glog/logging.h" + #include "paddle/memory/detail/buddy_allocator.h" #include "paddle/memory/detail/system_allocator.h" +#include "paddle/platform/gpu_info.h" + +DECLARE_double(fraction_of_gpu_memory_to_use); namespace paddle { namespace memory { @@ -80,6 +85,11 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { platform::GpuMinChunkSize(), platform::GpuMaxChunkSize())); } + VLOG(3) << "\n\nNOTE: each GPU device use " + << FLAGS_fraction_of_gpu_memory_to_use * 100 << "% of GPU memory.\n" + << "You can set environment variable '" + << platform::kEnvFractionGpuMemoryToUse + << "' to change the fraction of GPU usage.\n\n"; }); platform::SetDeviceId(gpu_id); diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h index 72351b9dfa63513713463bb47a3684f0dfd84ad3..11bbb881874ec50e1132547336fc6fb6b42bcc4f 100644 --- a/paddle/memory/memory.h +++ b/paddle/memory/memory.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once -#include "paddle/platform/gpu_info.h" #include "paddle/platform/place.h" namespace paddle { diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index affdd1ac2cd486930881ee6b34a4b32f41df7ee9..1e86fc3d166077265e0f433a6712b0665ea5a152 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -25,8 +25,8 @@ void gemm(const CBLAS_TRANSPOSE transA, const float alpha, const float* A, const float* B, const float beta, float* C, platform::DeviceContext* context) { - int lda = K; - int ldb = N; + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); @@ -40,8 +40,8 @@ void gemm(const CBLAS_TRANSPOSE transA, const double* B, const double beta, double* C, platform::DeviceContext* context) { - int lda = K; - int ldb = N; + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 95d19fb6aad37143e65759b03e12e3e78bce5915..460e458ca4f7f40746f0dbf7e258a165faa88e1a 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -18,6 +18,8 @@ namespace paddle { namespace operators { +using framework::Tensor; + class MulOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -59,10 +61,23 @@ class MulOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override {} - std::string DebugString() const override { - LOG(INFO) << "MulGrad"; - return ""; + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx.Input("X")->dims(); + auto y_dims = ctx.Input("Y")->dims(); + auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); + auto *x_grad = ctx.Output(framework::GradVarName("X")); + auto *y_grad = ctx.Output(framework::GradVarName("Y")); + PADDLE_ENFORCE(x_dims[0] == out_dims[0], + "Out@GRAD M X N must equal to X dims 0, M "); + PADDLE_ENFORCE(y_dims[1] == out_dims[1], + "Out@GRAD M X N must equal to Y dims 1, N "); + + x_grad->Resize(x_dims); + y_grad->Resize(y_dims); } }; @@ -72,3 +87,5 @@ class MulOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel); +REGISTER_OP_CPU_KERNEL(mul_grad, + ops::MulGradKernel); diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index 346a7e505d123b5e4e831daa39a1f6349b3dcccf..a81444dbe63edeecedc5d822c65ff56c42b5db90 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -17,3 +17,5 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); +REGISTER_OP_GPU_KERNEL(mul_grad, + ops::MulGradKernel); diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index b7812fd1a7a72f5ce543e18c8b7b5b51deff2204..8facc0281449785bf40726f23ca2fd5d166ff272 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -31,18 +31,34 @@ template class MulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - Eigen::array, 1> dim_pair = { - {Eigen::IndexPair(1, 0)}}; - auto* input0 = context.Input("X"); - auto* input1 = context.Input("Y"); - auto* output = context.Output("Out"); - output->mutable_data(context.GetPlace()); - auto X = EigenMatrix::From(*input0); - auto Y = EigenMatrix::From(*input1); - auto Z = EigenMatrix::From(*output); - auto& place = context.GetEigenDevice(); - - Z.device(place) = X.contract(Y, dim_pair); + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* Z = context.Output("Out"); + Z->mutable_data(context.GetPlace()); + auto* device_context = + const_cast(context.device_context_); + math::matmul(*X, false, *Y, false, 1, Z, 0, device_context); + } +}; + +template +class MulGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* X = ctx.Input("X"); + auto* Y = ctx.Input("Y"); + auto* dOut = ctx.Input(framework::GradVarName("Out")); + + auto* dX = ctx.Output(framework::GradVarName("X")); + auto* dY = ctx.Output(framework::GradVarName("Y")); + dX->mutable_data(ctx.GetPlace()); + dY->mutable_data(ctx.GetPlace()); + auto* device_context = + const_cast(ctx.device_context_); + // dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N + math::matmul(*dOut, false, *Y, true, 1, dX, 0, device_context); + // dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K + math::matmul(*X, true, *dOut, false, 1, dY, 0, device_context); } }; diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 8375d988045dc24fa1109646b46ff477e2a78132..6825dce332adc0dc11dda187d1bd367875b8603e 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -17,7 +17,9 @@ namespace paddle { namespace operators { -class RowWiseAddOp : public framework::OperatorWithKernel { +using framework::Tensor; + +class RowwiseAddOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -34,9 +36,9 @@ class RowWiseAddOp : public framework::OperatorWithKernel { } }; -class RowWiseAddOpMaker : public framework::OpProtoAndCheckerMaker { +class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker { public: - RowWiseAddOpMaker(framework::OpProto *proto, + RowwiseAddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The left input of row-wise add op, must be matrix"); @@ -49,12 +51,32 @@ for i in xrange(X.shape[0]): )DOC"); } }; +class RowwiseAddGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto dims0 = ctx.Input("X")->dims(); + auto dims1 = ctx.Input("b")->dims(); + PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1") + ctx.Output(framework::GradVarName("X"))->Resize(dims0); + ctx.Output(framework::GradVarName("b"))->Resize(dims1); + } +}; } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(rowwise_add, ops::RowWiseAddOp, - ops::RowWiseAddOpMaker); +REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker, + rowwise_add_grad, ops::RowwiseAddGradOp); +REGISTER_OP_CPU_KERNEL( + rowwise_add, ops::RowwiseAddKernel); REGISTER_OP_CPU_KERNEL( - rowwise_add, ops::RowWiseAddKernel); + rowwise_add_grad, + ops::RowwiseAddGradKernel); diff --git a/paddle/operators/rowwise_add_op.cu b/paddle/operators/rowwise_add_op.cu index 86f80b81228a69ac4c05a4693901570f2b9966e0..cbc61ad3e117fc79a674ca21831d3fec59d1ec5b 100644 --- a/paddle/operators/rowwise_add_op.cu +++ b/paddle/operators/rowwise_add_op.cu @@ -17,4 +17,4 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( - rowwise_add, ops::RowWiseAddKernel); + rowwise_add, ops::RowwiseAddKernel); diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index 01f88f2198774fbaa4c98ff9bf286f2f08496a9a..232135c38de68d4002e044972b282b43a1374c72 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -28,7 +28,7 @@ template ; template -class RowWiseAddKernel : public framework::OpKernel { +class RowwiseAddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto out = context.Output("Out"); @@ -47,5 +47,25 @@ class RowWiseAddKernel : public framework::OpKernel { } }; +template +class RowwiseAddGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* dOut = context.Input(framework::GradVarName("Out")); + auto* dX = context.Output(framework::GradVarName("X")); + auto* db = context.Output(framework::GradVarName("b")); + dX->mutable_data(context.GetPlace()); + db->mutable_data(context.GetPlace()); + + auto OutGrad = EigenMatrix::From(*dOut); + auto place = context.GetEigenDevice(); + EigenMatrix::From(*dX).device(place) = OutGrad; + + // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html + // colwise add + Eigen::array dims{{1}}; /* dimension to reduce */ + EigenVector::Flatten(*db).device(place) = OutGrad.sum(dims); + } +}; } // namespace operators } // namespace paddle diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index bfb449d0b029409eda4177fc7643810ee6a1df3d..a0b5000ffbf54364e15f87870913926a071fa972 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -30,7 +30,7 @@ class SGDOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto param = ctx.Input("param"); auto grad = ctx.Input("grad"); - auto param_out = ctx.Output(0); + auto param_out = ctx.Output("param_out"); float lr = ctx.op_.GetAttr("learning_rate"); param_out->mutable_data(ctx.GetPlace()); diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index acfc0639736beb82df41b851664e7bcd079b5eb1..120eb1e4af9cef43e76e27d4ad66acfbbd597a36 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -1,7 +1,7 @@ cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog) cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info) -nv_library(gpu_info SRCS gpu_info.cc DEPS gflags) +nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) @@ -9,6 +9,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) add_subdirectory(dynload) cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece) +cc_test(environment_test SRCS environment_test.cc DEPS stringpiece) IF(WITH_GPU) set(GPU_CTX_DEPS dynload_cuda dynamic_loader) diff --git a/paddle/platform/environment.h b/paddle/platform/environment.h new file mode 100644 index 0000000000000000000000000000000000000000..4edcce932edc61453cef74f2c4ee0f72496b3677 --- /dev/null +++ b/paddle/platform/environment.h @@ -0,0 +1,60 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/platform/enforce.h" +#include "paddle/string/piece.h" + +extern char** environ; // for environment variables + +namespace paddle { +namespace platform { + +inline void SetEnvVariable(const std::string& name, const std::string& value) { + PADDLE_ENFORCE_NE(setenv(name.c_str(), value.c_str(), 1), -1, + "Failed to set environment variable %s=%s", name, value); +} + +inline void UnsetEnvVariable(const std::string& name) { + PADDLE_ENFORCE_NE(unsetenv(name.c_str()), -1, + "Failed to unset environment variable %s", name); +} + +inline bool IsEnvVarDefined(const std::string& name) { + return std::getenv(name.c_str()) != nullptr; +} + +inline std::string GetEnvValue(const std::string& name) { + PADDLE_ENFORCE(IsEnvVarDefined(name), + "Tried to access undefined environment variable %s", name); + return std::getenv(name.c_str()); +} + +inline std::vector GetAllEnvVariables() { + std::vector vars; + for (auto var = environ; *var != nullptr; ++var) { + auto tail = string::Index(*var, "="); + auto name = string::SubStr(*var, 0, tail).ToString(); + vars.push_back(name); + } + return vars; +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/environment_test.cc b/paddle/platform/environment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f136527215d6a676cfa1a3b08f09dfd3ab24a90 --- /dev/null +++ b/paddle/platform/environment_test.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/platform/environment.h" + +#include "glog/logging.h" +#include "gtest/gtest.h" + +TEST(ENVIRONMENT, ACCESS) { + namespace platform = paddle::platform; + namespace string = paddle::string; + + platform::SetEnvVariable("PADDLE_USE_ENV", "TRUE"); + + EXPECT_TRUE(platform::IsEnvVarDefined("PADDLE_USE_ENV")); + EXPECT_EQ(platform::GetEnvValue("PADDLE_USE_ENV"), "TRUE"); + + platform::UnsetEnvVariable("PADDLE_USE_ENV"); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV")); + + platform::SetEnvVariable("PADDLE_USE_ENV1", "Hello "); + platform::SetEnvVariable("PADDLE_USE_ENV2", "World, "); + platform::SetEnvVariable("PADDLE_USE_ENV3", "PaddlePaddle!"); + + std::string env_info; + auto vars = platform::GetAllEnvVariables(); + for_each(vars.begin(), vars.end(), [&](const std::string& var) { + env_info += platform::GetEnvValue(var); + }); + + EXPECT_TRUE(string::Contains(env_info, "Hello World, PaddlePaddle!")); + platform::UnsetEnvVariable("PADDLE_USE_ENV1"); + platform::UnsetEnvVariable("PADDLE_USE_ENV2"); + platform::UnsetEnvVariable("PADDLE_USE_ENV3"); + + env_info.clear(); + vars = platform::GetAllEnvVariables(); + for_each(vars.begin(), vars.end(), [&](const std::string& var) { + env_info += platform::GetEnvValue(var); + }); + + EXPECT_FALSE(string::Contains(env_info, "Hello World, PaddlePaddle!")); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV1")); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV2")); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV3")); +} diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc index edeb3ecd7bf8b87333813eee5b40f71030f6609f..be381a4e26cf0eb41f5b3de88bd03ad8901683cc 100644 --- a/paddle/platform/gpu_info.cc +++ b/paddle/platform/gpu_info.cc @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/platform/gpu_info.h" + #include "gflags/gflags.h" + #include "paddle/platform/enforce.h" +#include "paddle/platform/environment.h" DEFINE_double(fraction_of_gpu_memory_to_use, 0.95, "Default use 95% of GPU memory for PaddlePaddle," @@ -70,6 +73,13 @@ size_t GpuMaxChunkSize() { GpuMemoryUsage(available, total); + if (IsEnvVarDefined(kEnvFractionGpuMemoryToUse)) { + auto val = std::stod(GetEnvValue(kEnvFractionGpuMemoryToUse)); + PADDLE_ENFORCE_GT(val, 0.0); + PADDLE_ENFORCE_LE(val, 1.0); + FLAGS_fraction_of_gpu_memory_to_use = val; + } + // Reserving the rest memory for page tables, etc. size_t reserving = (1 - FLAGS_fraction_of_gpu_memory_to_use) * total; diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h index d3a5f5f13fdd3dd59eb43465da4a64b0d8d95e5b..ed2420b8740e583d307f6836a70fe7e1c780e28b 100644 --- a/paddle/platform/gpu_info.h +++ b/paddle/platform/gpu_info.h @@ -18,10 +18,15 @@ limitations under the License. */ #include #include +#include namespace paddle { namespace platform { +//! Environment variable: fraction of GPU memory to use on each device. +const std::string kEnvFractionGpuMemoryToUse = + "PADDLE_FRACTION_GPU_MEMORY_TO_USE"; + //! Get the total number of GPU devices in system. int GetDeviceCount(); diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 4c088e7612a93be1b52bc015babee382bbd9026d..ce57a0713092723b6a99b2416e06ff1a436f043b 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -25,4 +25,5 @@ py_test(test_operator SRCS test_operator.py) # py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py) +py_test(test_sgd_op SRCS test_sgd_op.py) py_test(test_gradient_checker SRCS test_gradient_checker.py) diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py index ec0ac99156a546dd3fb7b27778032bece38ab5a9..ee0d81a64efcb81bae8b11b856c201a86da274e9 100644 --- a/python/paddle/v2/framework/tests/test_mul_op.py +++ b/python/paddle/v2/framework/tests/test_mul_op.py @@ -1,6 +1,7 @@ import unittest -from op_test_util import OpTestMeta import numpy as np +from gradient_checker import GradientChecker, create_op +from op_test_util import OpTestMeta class TestMulOp(unittest.TestCase): @@ -15,5 +16,19 @@ class TestMulOp(unittest.TestCase): self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])} +class MulGradOpTest(GradientChecker): + def test_mul(self): + op = create_op("mul") + inputs = { + 'X': np.random.random((32, 84)).astype("float32"), + 'Y': np.random.random((84, 100)).astype("float32") + } + # mul op will enlarge the relative error + self.check_grad( + op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5) + + +# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py index f8521eb517057fbeb104b28af7da4fffe54f37de..29d72e850099734a9828ccceed47cc0a57fc3d6b 100644 --- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py +++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py @@ -1,6 +1,7 @@ import unittest -from op_test_util import OpTestMeta import numpy as np +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op class TestRowwiseAddOp(unittest.TestCase): @@ -15,5 +16,15 @@ class TestRowwiseAddOp(unittest.TestCase): self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])} +class RowwiseAddGradOpTest(GradientChecker): + def test_rowwise_add(self): + op = create_op("rowwise_add") + inputs = { + "X": np.random.uniform(0.1, 1, [10, 10]).astype("float32"), + "b": np.random.uniform(0.1, 1, [10]).astype("float32") + } + self.check_grad(op, inputs, set(["X", "b"]), "Out") + + if __name__ == '__main__': unittest.main()