快速浏览Apache MXNet 1.2中的Swish激活函数
来源:不靠谱的猫     阅读:671
黑蚁网络
发布于 2018-07-09 22:19
查看主页

Apache MXNet 1.2是即将发布。正如更改日志所暗示的,这看起来像是一个主要版本,但在本文中,我们将重点详情一个新的激活函数:Swish。

快速回顾激活函数

在深度神经网络中,激活函数的目的是引入非线性,即对神经元输出执行非线性决策阈值。从某种意义上说,我们试图模仿一种简单的方式,无疑是生物神经元的行为。

随着时间的推移,许多激活函数被设计出来,每一个新函数都试图克服其前辈的缺点。例如, Rectified Linear Unit function(又名ReLU)通过求解消失梯度问题对Sigmoid函数进行了改进。

当然,更好的激活函数的竞赛从未中止。2017年末,发现了一个新函数:Swish。

Swish函数

通过自动组合不同的数学运算符,Prajit Ramachandran,Barret Zoph和Quoc V评估了大量候选激活函数的性可以。其中一个,他们称之为Swish,结果比其余好。

快速浏览Apache MXNet 1.2中的Swish激活函数

f(x)=x?sigmoid(βx)

正如你所看到的,假如β参数小,Swish是接近线性,当β参数很小,接近ReLU时大。β的最佳点似乎在1到2之间:它创立一个non-monotonic“bump”对于负似乎有趣的属性。

正如作者所强调的那样:“ 简单地使用Swish单元代替ReLUs,对于ImageNet来说,移动NASNet-A和Inception-ResNet-v2的分类精度分别提高了0.9%和0.6%。

这听起来像是一个简单的改进,不是吗?让我们在MXNet上测试它!

Swish in MXNet

Swish可使用于MXNet 1.2中的Gluon API。它在incubator-mxnet / python / mxnet / gluon / nn / activations.py中定义,并且在我们的Gluon代码中用它很容易:nn.Swish()。

为了评估它的性可以,我们将在CIFAR-10上训练VGG16卷积神经网络的两个不同版本:

  • 在Gluon model zoo中用批量标准化的VGG16 ,即便使用ReLU(incubator-mxnet / python / mxnet / gluon / model_zoo / vision / vgg.py)

  • 相同的网络被修改为用Swish作为卷积层和全连接层。

这非常简单:从主分支开始,我们只要创立一个vggswish.py文件并使用Swish代替ReLU,例如:

self.features.add(nn.Dense(4096,activation ='relu',weight_initializer ='normal',bias_initializer ='zeros'))

--->

self.features.add(nn.Dense(4096,weight_initializer ='normal',bias_initializer ='zeros'))

self.features.add(nn.Swish())

而后,我们将这组新模型插入到incubator-mxnet / python / mxnet / gluon / model_zoo / __ init__.py和voila!

Training on CIFAR-10

MXNet包含一个图像分类脚本,可让我们用各种网络体系结构和数据集(incubator-mxnet / example / gluon / image_classification.py)进行训练。

我们将用SGD进行时代步骤,每次将学习率除以10。

python image_classification.py --model vgg16_bn --batch-size = 128 --lr = 0.1 --lr-steps = '10,20,30'--epochs = 40

python image_classification.py --model vgg16_bn_swish --batch-size = 128 --lr = 0.1 --lr-steps = '10,20,30'--epochs = 40

结果如下。

快速浏览Apache MXNet 1.2中的Swish激活函数

VGG16与ReLU与VGG16与Swish

从一个单一的例子中得出结论总是困难的(虽然我的确做了很多不同的训练,结果是一致的)。在这里,我们能看到比Swish版本似乎更快的训练和验证以及ReLU版本。

最高验证精度分别为14次迭代0.866186和19次迭代0.866001。差别很小,但是VGG16并不是一个非常深的网络。

免责声明:本文为用户发表,不代表网站立场,仅供参考,不构成引导等用途。 系统环境 软件环境
相关推荐
第30届中国控制与决策会议在沈阳胜利召开
聊聊storm的GraphiteStormReporter
简述Nginx在windows下的配置及部署
图像检索---索引技术
Linux ping/netstat参数妙使用
首页
搜索
订单
购物车
我的