keras源码engine中toplogy.py定义了加载权重的函数:
load_weights(self, filepath, by_name=False)
其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16
VGG19/resnet50等,源码描述如下:
If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.
若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了
模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:
If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.
在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用
model.load_weights(filepath,by_name=True)
补充知识:Keras下实现mnist手写数字
之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,
代码如下。
import struct import numpy as np import os import keras from keras.models import Sequential from keras.layers import Dense from keras.optimizers import SGD def load_mnist(path, kind): labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind) images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind) with open(labels_path, 'rb') as lbpath: magic, n = struct.unpack('>II', lbpath.read(8)) labels = np.fromfile(lbpath, dtype=np.uint8) with open(images_path, 'rb') as imgpath: magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16)) images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784 return images, labels #loading train and test data X_train, Y_train = load_mnist('.\\data', kind='train') X_test, Y_test = load_mnist('.\\data', kind='t10k') #turn labels to one_hot code Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10) #define models model = Sequential() model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh')) model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh')) model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True) model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"]) #start training model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3) #count accuracy y_train_pred = model.predict_classes(X_train, verbose=0) train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] print('Training accuracy: %.2f%%' % (train_acc * 100)) y_test_pred = model.predict_classes(X_test, verbose=0) test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] print('Test accuracy: %.2f%%' % (test_acc * 100))
训练结果如下:
Epoch 45/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323 Epoch 46/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358 Epoch 47/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347 Epoch 48/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350 Epoch 49/50 42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359 Epoch 50/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346 Training accuracy: 94.11% Test accuracy: 93.61%
以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
keras,导入,weights
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
RTX 5090要首发 性能要翻倍!三星展示GDDR7显存
三星在GTC上展示了专为下一代游戏GPU设计的GDDR7内存。
首次推出的GDDR7内存模块密度为16GB,每个模块容量为2GB。其速度预设为32 Gbps(PAM3),但也可以降至28 Gbps,以提高产量和初始阶段的整体性能和成本效益。
据三星表示,GDDR7内存的能效将提高20%,同时工作电压仅为1.1V,低于标准的1.2V。通过采用更新的封装材料和优化的电路设计,使得在高速运行时的发热量降低,GDDR7的热阻比GDDR6降低了70%。
更新日志
- 魔兽世界奥卡兹岛地牢入口在哪里 奥卡兹岛地牢入口位置一览
- 和文军-丽江礼物[2007]FLAC
- 陈随意2012-今生的伴[豪记][WAV+CUE]
- 罗百吉.2018-我们都一样【乾坤唱片】【WAV+CUE】
- 《怪物猎人:荒野》不加中配请愿书引热议:跪久站不起来了?
- 《龙腾世纪4》IGN 9分!殿堂级RPG作品
- Twitch新规禁止皮套外露敏感部位 主播直接“真身”出镜
- 木吉他.1994-木吉他作品全集【滚石】【WAV+CUE】
- 莫华伦.2022-一起走过的日子【京文】【WAV+CUE】
- 曾淑勤.1989-装在袋子里的回忆【点将】【WAV+CUE】
- 滚石香港黄金十年系列《赵传精选》首版[WAV+CUE][1.1G]
- 雷婷《乡村情歌·清新民谣》1:1母盘直刻[低速原抓WAV+CUE][1.1G]
- 群星 《DJ夜色魅影HQⅡ》天艺唱片[WAV+CUE][1.1G]
- 群星《烧透你的耳朵2》DXD金佰利 [低速原抓WAV+CUE][1.3G]
- 群星《难忘的回忆精选4》宝丽金2CD[WAV+CUE][1.4G]