Commit 51a86d2b authored by Aurelius84, committed by hong

Optimize adam speed (#21777)

* optimize adam speed by removing _finish_update test=develop

* fix SparseAdamFunctor param list test=develop

* Remove scale_op in expect_list of adam_op test=develop

* fix test optimizer loss assert error test=develop

* fix test optimizer loss assert error test=develop

* modify PADDLE_ENFORCE usage test=develop

* fix op_type in lamb_op.cc test=develop

* fix errors ostream format bug test=develop

* add betaPowOut in ngraph op test=develop

* fix ngraph::op api for gcc8 test=develop

* clean code test=develop

* modify struct into class test=develop

* remove code of beta1Tensor in lamb_op test=develop
Parent 310edc0d
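
The gist of the change: the adam op now also writes Beta1PowOut and Beta2PowOut, so the Python AdamOptimizer no longer appends two extra scale ops per parameter in _finish_update, and the fuse pass no longer has to fuse those scale ops. Below is a minimal NumPy sketch of the update the op performs, consistent with its documented update rule; the function and argument names here are illustrative, not the op's attribute names.

import numpy as np

def fused_adam_step(param, grad, m1, m2, lr, beta1_pow, beta2_pow,
                    beta1=0.9, beta2=0.999, epsilon=1e-8):
    # Moment and parameter update (unchanged Adam math).
    m1_out = beta1 * m1 + (1 - beta1) * grad
    m2_out = beta2 * m2 + (1 - beta2) * grad * grad
    lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
    param_out = param - lr_t * m1_out / (np.sqrt(m2_out) + epsilon)
    # New in this commit: the beta-pow accumulators are advanced by the same
    # op instead of by separate scale ops appended in _finish_update.
    beta1_pow_out = beta1_pow * beta1
    beta2_pow_out = beta2_pow * beta2
    return param_out, m1_out, m2_out, beta1_pow_out, beta2_pow_out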
......@@ -39,13 +39,6 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
auto fused_adam_node =
FuseAdamOps(aux_var_set, fused_vars_name, adam_ops, graph);
auto fused_scale1 =
FuseScaleOps(aux_var_set.at("Beta1Pow"), fused_vars_name.at("Beta1Pow"),
adam_ops, graph);
auto fused_scale2 =
FuseScaleOps(aux_var_set.at("Beta2Pow"), fused_vars_name.at("Beta2Pow"),
adam_ops, graph);
RemoveCycleDepsBetweenOpNodes(graph, fused_scale1, fused_scale2);
return fused_adam_node;
}
......@@ -139,6 +132,8 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
adam_desc.SetOutput("ParamOut", {fused_vars_name.at(kParam)});
adam_desc.SetOutput("Moment1Out", {fused_vars_name.at("Moment1")});
adam_desc.SetOutput("Moment2Out", {fused_vars_name.at("Moment2")});
adam_desc.SetOutput("Beta1PowOut", {fused_vars_name.at("Beta1Pow")});
adam_desc.SetOutput("Beta2PowOut", {fused_vars_name.at("Beta2Pow")});
adam_desc.SetAttr("beta1", beta1);
adam_desc.SetAttr("beta2", beta2);
adam_desc.SetAttr("epsilon", epsilon);
......
......@@ -416,6 +416,8 @@ void FuseOptimizerOpPass::FuseVarsToContinuousSpace(
result->Get<details::ProgramDescs>(details::kProgramDescs).back();
auto *global_block = program_desc.MutableBlock(0);
for (auto &var_name : aux_var_names) {
VLOG(6) << "aux_var_names : " << var_name
<< ". fused_vars_name: " << fused_vars_name.at(var_name);
AppendCoalesceTensorOp(aux_var_map.at(var_name), aux_var_map.at(var_name),
fused_vars_name.at(var_name), dtype, global_block,
true);
......
......@@ -68,9 +68,14 @@ void BuildAdamNode(
auto delta = ElementwiseScalar<ngraph::op::Multiply>(updated_lr, param_grad);
auto param_out = std::make_shared<ngraph::op::Subtract>(param, delta);
auto beta1_pow_out = ElementwiseScalar<ngraph::op::Multiply>(beta1, beta1pow);
auto beta2_pow_out = ElementwiseScalar<ngraph::op::Multiply>(beta2, beta2pow);
platform::SetOutputNode(op, "Moment1Out", moment1out, ngb_node_map);
platform::SetOutputNode(op, "Moment2Out", moment2out, ngb_node_map);
platform::SetOutputNode(op, "ParamOut", param_out, ngb_node_map);
platform::SetOutputNode(op, "Beta1PowOut", beta1_pow_out, ngb_node_map);
platform::SetOutputNode(op, "Beta2PowOut", beta2_pow_out, ngb_node_map);
}
} // namespace ngraphs
} // namespace operators
......
......@@ -66,37 +66,63 @@ void AdamOp::InferShape(framework::InferShapeContext* ctx) const {
"Output(Moment2Out) of AdamOp should not be null."));
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
"Maybe the Input variable LearningRate has not "
PADDLE_ENFORCE_NE(
framework::product(lr_dims), 0,
platform::errors::InvalidArgument(
"The number of LearningRate shall not be 0, but received %d. Maybe "
"the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 dimension");
"after optimizer.minimize function.",
framework::product(lr_dims)));
PADDLE_ENFORCE_EQ(
framework::product(lr_dims), 1,
platform::errors::InvalidArgument(
"Learning rate should have 1 dimension, but received %d",
framework::product(lr_dims)));
auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
"Beta1 power accumulator should have 1 dimension");
VLOG(3) << "dims of Beta1Pow : [" << beta1_pow_dims << "]";
PADDLE_ENFORCE_GE(framework::product(beta1_pow_dims), 1,
platform::errors::InvalidArgument(
"The size of Beta1 power accumulator should be greater "
"than 0, but received %d.",
framework::product(beta1_pow_dims)));
auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1,
"Beta2 power accumulator should have 1 dimension");
VLOG(3) << "dims of Beta2Pow : [" << beta2_pow_dims << "]";
PADDLE_ENFORCE_GE(framework::product(beta2_pow_dims), 1,
platform::errors::InvalidArgument(
"The size of Beta2 power accumulator should be greater "
"than 0, but received %d.",
framework::product(beta2_pow_dims)));
auto param_dims = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdamOp should have same dimension");
platform::errors::InvalidArgument(
"Param and Grad input of AdamOp should have same dimension. But "
"received Param dims: [%s], Grad dims: [%s].",
param_dims, ctx->GetInputDim("Grad")));
}
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment1"),
"Param and Moment1 input of AdamOp should have same dimension");
platform::errors::InvalidArgument(
"Param and Moment1 input of AdamOp should have same dimension. But "
"received Param dims: [%s], Moment1 dims: [%s].",
param_dims, ctx->GetInputDim("Moment1")));
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment2"),
"Param and Moment2 input of AdamOp should have same dimension");
platform::errors::InvalidArgument(
"Param and Moment2 input of AdamOp should have same dimension. But "
"received Param dims: [%s], Moment2 dims: [%s].",
param_dims, ctx->GetInputDim("Moment2")));
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims);
ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims);
ctx->SetOutputDim("Beta2PowOut", beta2_pow_dims);
}
framework::OpKernelType AdamOp::GetExpectedKernelType(
......@@ -130,6 +156,8 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("Moment1Out", "(Tensor) Output first moment");
AddOutput("Moment2Out", "(Tensor) Output second moment");
AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator");
AddAttr<float>("beta1",
"(float, default 0.9) "
......
......@@ -52,10 +52,48 @@ struct GPUAdam;
struct CPUAdam;
template <typename T, typename Flavour>
struct AdamFunctor;
class AdamFunctor;
template <typename T>
struct AdamFunctor<T, GPUAdam> {
class BetaPowFunctor {
private:
T beta1_;
T beta2_;
const T* beta1_pow_;
const T* beta2_pow_;
T* beta1_pow_out_;
T* beta2_pow_out_;
public:
BetaPowFunctor(T beta1, T beta2, const T* beta1_pow, const T* beta2_pow,
T* beta1_pow_out, T* beta2_pow_out)
: beta1_(beta1),
beta2_(beta2),
beta1_pow_(beta1_pow),
beta2_pow_(beta2_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_out_(beta2_pow_out) {}
inline HOSTDEVICE void update_step(size_t i) const {
T beta1_pow_i = beta1_pow_[i];
T beta2_pow_i = beta2_pow_[i];
beta1_pow_out_[i] = beta1_pow_i * beta1_;
beta2_pow_out_[i] = beta2_pow_i * beta2_;
}
inline HOSTDEVICE void operator()(size_t i) const { update_step(i); }
inline HOSTDEVICE void apply_update(size_t limit) const {
for (size_t i = 0; i < limit; ++i) {
update_step(i);
}
}
};
template <typename T>
class AdamFunctor<T, GPUAdam> {
private:
T beta1_;
T beta2_;
T epsilon_;
......@@ -71,6 +109,7 @@ struct AdamFunctor<T, GPUAdam> {
const T* param_;
T* param_out_;
public:
AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
T* mom2_out, const T* lr, const T* grad, const T* param,
......@@ -114,7 +153,8 @@ struct AdamFunctor<T, GPUAdam> {
};
template <typename T>
struct AdamFunctor<T, CPUAdam> {
class AdamFunctor<T, CPUAdam> {
private:
T beta1_;
T beta2_;
T epsilon_;
......@@ -130,6 +170,7 @@ struct AdamFunctor<T, CPUAdam> {
const T* param_;
T* param_out_;
public:
AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
T* mom2_out, const T* lr, const T* grad, const T* param,
......@@ -179,10 +220,11 @@ struct AdamFunctor<T, CPUAdam> {
};
template <typename T, typename Flavour>
struct SparseAdamFunctor;
class SparseAdamFunctor;
template <typename T>
struct SparseAdamFunctor<T, GPUAdam> {
class SparseAdamFunctor<T, GPUAdam> {
private:
T beta1_;
T beta2_;
T epsilon_;
......@@ -203,6 +245,7 @@ struct SparseAdamFunctor<T, GPUAdam> {
int64_t row_count_;
bool lazy_mode_;
public:
SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* lr, const T* grad,
......@@ -261,7 +304,8 @@ struct SparseAdamFunctor<T, GPUAdam> {
};
template <typename T>
struct SparseAdamFunctor<T, CPUAdam> {
class SparseAdamFunctor<T, CPUAdam> {
private:
T beta1_;
T beta2_;
T epsilon_;
......@@ -281,6 +325,7 @@ struct SparseAdamFunctor<T, CPUAdam> {
int64_t row_numel_;
int64_t row_count_;
public:
SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* lr, const T* grad,
......@@ -397,6 +442,10 @@ class AdamOpKernel : public framework::OpKernel<T> {
Ref(ctx.Output<LoDTensor>("Moment1Out"), "Must set Moment1Out");
auto& mom2_out =
Ref(ctx.Output<LoDTensor>("Moment2Out"), "Must set Moment1Out");
auto& beta1_pow_out =
Ref(ctx.Output<LoDTensor>("Beta1PowOut"), "Must set Beta1PowOut");
auto& beta2_pow_out =
Ref(ctx.Output<LoDTensor>("Beta2PowOut"), "Must set Beta2PowOut");
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
if (ctx.HasInput("Beta1Tensor")) {
......@@ -408,6 +457,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
}
VLOG(3) << "beta1_pow.numel() : " << beta1_pow.numel()
<< "beta2_pow.numel() : " << beta2_pow.numel();
VLOG(3) << "param.numel(): " << param.numel();
BetaPowFunctor<T> beta_functor(
beta1, beta2, beta1_pow.template data<T>(),
beta2_pow.template data<T>(),
beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
beta2_pow_out.template mutable_data<T>(ctx.GetPlace()));
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
......@@ -423,6 +480,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()));
functor(param.numel());
beta_functor.apply_update(beta2_pow.numel());
} else if (platform::is_gpu_place(ctx.GetPlace())) {
AdamFunctor<T, GPUAdam> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
......@@ -433,11 +491,16 @@ class AdamOpKernel : public framework::OpKernel<T> {
lr.template data<T>(), grad.template data<T>(),
param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()));
// update param and moment
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param.numel());
for_range(functor);
// update beta1 and beta2
platform::ForRange<DeviceContext> for_range_beta(
static_cast<const DeviceContext&>(ctx.device_context()),
beta2_pow.numel());
for_range_beta(beta_functor);
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto& grad =
......@@ -485,6 +548,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode);
// update beta1 and beta2
beta_functor.apply_update(beta2_pow.numel());
if (lazy_mode) {
VLOG(3) << "run cpu lazy mode";
size_t row_count = grad_merge.rows().size();
......@@ -574,6 +639,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
static_cast<const DeviceContext&>(ctx.device_context()),
param.numel());
for_range(functor);
// update beta1 and beta2
platform::ForRange<DeviceContext> for_range_beta(
static_cast<const DeviceContext&>(ctx.device_context()),
beta2_pow.numel());
for_range_beta(beta_functor);
}
} else {
PADDLE_THROW("Variable type not supported by adam_op");
......
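
The new BetaPowFunctor keeps the beta-pow update inside the kernel: the Adam functor still runs over param.numel() elements, while the beta-pow functor runs only over beta2_pow.numel() (normally a single element), so absorbing the removed scale ops costs only a couple of scalar multiplies per step. A rough Python rendering of the functor, assuming 1-D NumPy buffers stand in for the raw pointers:

import numpy as np

class BetaPowFunctor:
    """Python rendering of the C++ BetaPowFunctor (illustrative only)."""

    def __init__(self, beta1, beta2, beta1_pow, beta2_pow,
                 beta1_pow_out, beta2_pow_out):
        self.beta1, self.beta2 = beta1, beta2
        self.beta1_pow, self.beta2_pow = beta1_pow, beta2_pow
        self.beta1_pow_out, self.beta2_pow_out = beta1_pow_out, beta2_pow_out

    def update_step(self, i):
        # Mirrors BetaPowFunctor::update_step(i) in adam_op.h.
        self.beta1_pow_out[i] = self.beta1_pow[i] * self.beta1
        self.beta2_pow_out[i] = self.beta2_pow[i] * self.beta2

    def apply_update(self, limit):
        # CPU path: plain loop over the accumulator elements. The GPU path
        # instead invokes update_step per index via platform::ForRange.
        for i in range(limit):
            self.update_step(i)

# Usage: the accumulators are single-element tensors.
b1p = np.array([0.9 ** 10], dtype=np.float32)
b2p = np.array([0.999 ** 10], dtype=np.float32)
b1p_out, b2p_out = np.empty_like(b1p), np.empty_like(b2p)
BetaPowFunctor(0.9, 0.999, b1p, b2p, b1p_out, b2p_out).apply_update(b2p.size)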
......@@ -13,11 +13,111 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/lamb_op.h"
#include "paddle/fluid/operators/optimizers/adam_op.h"
namespace paddle {
namespace operators {
class LambOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
platform::errors::NotFound(
"Input(Param) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
platform::errors::NotFound(
"Input(Grad) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Moment1"), true,
platform::errors::NotFound(
"Input(Moment1) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Moment2"), true,
platform::errors::NotFound(
"Input(Moment2) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("LearningRate"), true,
platform::errors::NotFound(
"Input(LearningRate) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Beta1Pow"), true,
platform::errors::NotFound(
"Input(Beta1Pow) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Beta2Pow"), true,
platform::errors::NotFound(
"Input(Beta2Pow) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
platform::errors::NotFound(
"Output(ParamOut) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment1Out"), true,
platform::errors::NotFound(
"Output(Moment1Out) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment2Out"), true,
platform::errors::NotFound(
"Output(Moment2Out) of LambOp should not be null."));
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(
framework::product(lr_dims), 0,
platform::errors::InvalidArgument(
"The number of LearningRate shall not be 0, but received %d. Maybe "
"the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.",
framework::product(lr_dims)));
PADDLE_ENFORCE_EQ(
framework::product(lr_dims), 1,
platform::errors::InvalidArgument(
"Learning rate should have 1 dimension, but received %d.",
framework::product(lr_dims)));
auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
PADDLE_ENFORCE_GE(framework::product(beta1_pow_dims), 1,
platform::errors::InvalidArgument(
"The size of Beta1 power accumulator should be "
"greater than 0, but received %d.",
framework::product(beta1_pow_dims)));
auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
PADDLE_ENFORCE_GE(framework::product(beta2_pow_dims), 1,
platform::errors::InvalidArgument(
"The size of Beta2 power accumulator should be "
"greater than 0, but received %d.",
framework::product(beta2_pow_dims)));
auto param_dims = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument(
"Param and Grad input of LambOp should have same dimension. But "
"received Param dims: [%s], Grad dims: [%s].",
param_dims, ctx->GetInputDim("Grad")));
}
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment1"),
platform::errors::InvalidArgument(
"Param and Moment1 input of LambOp should have same dimension. But "
"received Param dims: [%s], Moment1 dims: [%s].",
param_dims, ctx->GetInputDim("Moment1")));
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment2"),
platform::errors::InvalidArgument(
"Param and Moment2 input of LambOp should have same dimension. But "
"received Param dims: [%s], Moment2 dims: [%s].",
param_dims, ctx->GetInputDim("Moment2")));
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class LambOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -79,7 +179,7 @@ learning rate, $\lambda$ the weight decay rate.
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(lamb, ops::AdamOp, ops::LambOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(lamb, ops::LambOp, ops::LambOpMaker);
REGISTER_OP_CPU_KERNEL(
lamb, ops::LambOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LambOpKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -1683,7 +1683,9 @@ class AdamOptimizer(Optimizer):
outputs = {
"ParamOut": param_and_grad[0],
"Moment1Out": moment1,
"Moment2Out": moment2
"Moment2Out": moment2,
"Beta1PowOut": beta1_pow_acc,
"Beta2PowOut": beta2_pow_acc,
}
attrs = {
"epsilon": self._epsilon,
......@@ -1709,46 +1711,6 @@ class AdamOptimizer(Optimizer):
return adam_op
def _finish_update(self, block, param_and_grads):
"""Update Beta1 and Beta2 Power accumulators
"""
assert isinstance(block, framework.Block)
main_block = block.program.global_block()
for param, grad in param_and_grads:
if grad is None or param.trainable is False:
continue
with param.block.program._optimized_guard(
[param, grad]), name_scope("optimizer"):
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param)
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param)
inputs = {"X": beta1_pow_acc}
attrs = {}
if isinstance(self._beta1, Variable):
inputs['ScaleTensor'] = self._beta1
else:
attrs['scale'] = self._beta1
main_block.append_op(
type="scale",
inputs=inputs,
outputs={"Out": beta1_pow_acc},
attrs=attrs,
stop_gradient=True)
inputs = {"X": beta2_pow_acc}
attrs = {}
if isinstance(self._beta2, Variable):
inputs['ScaleTensor'] = self._beta2
else:
attrs['scale'] = self._beta2
main_block.append_op(
type="scale",
inputs=inputs,
outputs={"Out": beta2_pow_acc},
attrs=attrs,
stop_gradient=True)
class AdamaxOptimizer(Optimizer):
"""
......
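
With Beta1PowOut/Beta2PowOut wired as op outputs, _append_optimize_op emits one adam op per parameter and _finish_update disappears entirely. A hedged sketch of the resulting append_op call follows; variable creation is omitted, block and the accumulator variables are assumed to already exist, and beta1/beta2 may alternatively be fed through the Beta1Tensor/Beta2Tensor inputs.

def append_fused_adam_op(block, param, grad, lr, moment1, moment2,
                         beta1_pow_acc, beta2_pow_acc,
                         beta1=0.9, beta2=0.999, epsilon=1e-8):
    # A single adam op now updates the moments, the parameter and both
    # beta-pow accumulators, so no follow-up scale ops are appended.
    return block.append_op(
        type="adam",
        inputs={
            "Param": param,
            "Grad": grad,
            "LearningRate": lr,
            "Moment1": moment1,
            "Moment2": moment2,
            "Beta1Pow": beta1_pow_acc,
            "Beta2Pow": beta2_pow_acc,
        },
        outputs={
            "ParamOut": param,
            "Moment1Out": moment1,
            "Moment2Out": moment2,
            "Beta1PowOut": beta1_pow_acc,
            "Beta2PowOut": beta2_pow_acc,
        },
        attrs={"beta1": beta1, "beta2": beta2, "epsilon": epsilon},
        stop_gradient=True)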
......@@ -58,7 +58,9 @@ class TestAdamOp1(OpTest):
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out
'ParamOut': param_out,
'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
}
def test_check_output(self):
......@@ -101,7 +103,9 @@ class TestAdamOp2(OpTest):
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out
'ParamOut': param_out,
'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
}
def test_check_output(self):
......@@ -122,11 +126,11 @@ class TestAdamOpMultipleSteps(OpTest):
moment2 = np.random.random((102, 105)).astype("float32")
learning_rate = 0.001
beta1 = 0.9
beta2 = 0.999
self.beta1 = 0.9
self.beta2 = 0.999
epsilon = 1e-8
beta1_pow = beta1**10
beta2_pow = beta2**10
self.beta1_pow = self.beta1**10
self.beta2_pow = self.beta2**10
self.inputs = {
'Param': param,
......@@ -134,21 +138,29 @@ class TestAdamOpMultipleSteps(OpTest):
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32")
'Beta1Pow': np.array([self.beta1_pow]).astype("float32"),
'Beta2Pow': np.array([self.beta2_pow]).astype("float32")
}
self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
self.attrs = {
'epsilon': epsilon,
'beta1': self.beta1,
'beta2': self.beta2
}
def test_check_output(self):
for _ in range(self.num_steps):
param_out, moment1_out, \
moment2_out = adam_step(self.inputs, self.attrs)
beta1_pow_out = self.inputs['Beta1Pow'] * self.beta1
beta2_pow_out = self.inputs['Beta2Pow'] * self.beta2
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out
'ParamOut': param_out,
'Beta1PowOut': beta1_pow_out,
'Beta2PowOut': beta2_pow_out
}
# Verify output for this step
......@@ -160,8 +172,8 @@ class TestAdamOpMultipleSteps(OpTest):
self.inputs['Moment2'] = moment2_out
# Update powers of Beta1 and Beta2 for next time step
self.inputs['Beta1Pow'] *= self.attrs['beta1']
self.inputs['Beta2Pow'] *= self.attrs['beta1']
self.inputs['Beta1Pow'] = beta1_pow_out
self.inputs['Beta2Pow'] = beta2_pow_out
# Randomize gradient for next step
self.inputs['Grad'] = np.random.uniform(
......@@ -254,6 +266,8 @@ class TestSparseAdamOp(unittest.TestCase):
beta1 = 0.78
beta2 = 0.836
epsilon = 1e-4
beta1_pow = np.array([beta1**10]).astype("float32")
beta2_pow = np.array([beta2**10]).astype("float32")
height = 10
rows = [0, 4, 7]
......@@ -264,8 +278,8 @@ class TestSparseAdamOp(unittest.TestCase):
"Param": np.full((height, row_numel), 5.0).astype("float32"),
"Moment1": np.full((height, row_numel), 5.0).astype("float32"),
"Moment2": np.full((height, row_numel), 5.0).astype("float32"),
'Beta1Pow': np.array([beta1**10]).astype("float32"),
'Beta2Pow': np.array([beta2**10]).astype("float32"),
'Beta1Pow': beta1_pow,
'Beta2Pow': beta2_pow,
"LearningRate": np.full((1), 2.0).astype("float32")
}
self.init_output = np.full((height, row_numel), 0.0).astype("float32")
......@@ -294,7 +308,9 @@ class TestSparseAdamOp(unittest.TestCase):
self.outputs = {
"ParamOut": param_out,
"Moment1Out": mom1,
"Moment2Out": mom2
"Moment2Out": mom2,
'Beta1PowOut': beta1_pow * beta1,
'Beta2PowOut': beta2_pow * beta2
}
def check_with_place(self, place, lazy_mode):
......@@ -376,7 +392,9 @@ class TestAdamOpBetaVariable(OpTest):
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out
'ParamOut': param_out,
'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
}
def test_check_output(self):
......
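
TestAdamOpMultipleSteps also quietly fixes an older slip: the previous loop advanced Beta2Pow with beta1. Both accumulators are now threaded through the op's own Beta1PowOut/Beta2PowOut between steps, roughly like this (step count illustrative):

import numpy as np

beta1, beta2 = 0.9, 0.999
beta1_pow = np.array([beta1 ** 10], dtype="float32")
beta2_pow = np.array([beta2 ** 10], dtype="float32")

for _ in range(10):
    # Expected op outputs for this step.
    beta1_pow_out = beta1_pow * beta1
    beta2_pow_out = beta2_pow * beta2
    # Feed them back as the next step's Beta1Pow / Beta2Pow inputs,
    # instead of multiplying the inputs inside the test itself.
    beta1_pow, beta2_pow = beta1_pow_out, beta2_pow_out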
......@@ -320,9 +320,8 @@ class TestAdamOptimizer(unittest.TestCase):
self.assertEqual(len(adam_optimizer.get_accumulators()), 0)
with framework.program_guard(program, init_program):
opts = adam_optimizer.apply_gradients(params_grads)
self.assertEqual(len(opts), 4)
self.assertEqual([op.type for op in opts],
["scale", "adam", "scale", "scale"])
self.assertEqual(len(opts), 2)
self.assertEqual([op.type for op in opts], ["scale", "adam"])
# Check accumulators
accumulators = adam_optimizer.get_accumulators()
......
......@@ -68,7 +68,7 @@ class TestTrainable(unittest.TestCase):
test_trainable,
feed_dict,
op_count={'adam': 1,
'scale': 2,
'scale': 0,
'mul_grad': 0})
self.check_trainable(
test_trainable,
......