好得很程序员自学网

<tfoot draggable='sEl'></tfoot>

模型加载与保存

目录

Outline · Save/load weights · Save/load entire model · saved_model

Outline

save/load weights  # 记录部分信息

save/load entire model  # 记录所有信息

saved_model  # 通用,包括 PyTorch、其他语言

Save/load weights

保存部分信息
# Illustrative snippet: `model`, `create_model`, `test_images` and
# `test_labels` are not defined here and must be supplied by the reader.

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights into a freshly created model
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

# Evaluate the restored model and report its accuracy as a percentage
loss, acc = model.evaluate(test_images, test_labels)
print(f'Restored model, accuracy: {100*acc:5.2f}')
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics


def preprocess(x, y):
    """Convert one (image, label) pair into the tensors the network consumes.

    x is a single image, not a batch: it is cast to float32, divided by 255
    and flattened to a 784-element vector.  y is cast to int32 and one-hot
    encoded over 10 classes.
    """
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [28 * 28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


def create_model():
    """Build and compile the dense classifier used throughout this example.

    The same architecture is needed twice: once for training and once as an
    empty shell that the saved weights are loaded back into.
    """
    model = Sequential([
        layers.Dense(256, activation='relu'),
        layers.Dense(128, activation='relu'),
        layers.Dense(64, activation='relu'),
        layers.Dense(32, activation='relu'),
        layers.Dense(10)
    ])
    # `learning_rate` replaces the deprecated `lr` keyword.
    model.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    return model


batchsz = 128
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())

# Training pipeline: per-sample preprocessing, shuffle, then batch.
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
# Validation pipeline: same preprocessing, no shuffling.
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)

sample = next(iter(db))
print(sample[0].shape, sample[1].shape)

network = create_model()
network.build(input_shape=(None, 28 * 28))
network.summary()

# Validation runs every second epoch (validation_freq=2).
network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)

network.evaluate(ds_val)

# Only the weights are written; the model object itself is then discarded.
network.save_weights('weights.ckpt')
print('saved weights.')
del network

# Recreate the identical architecture, then restore the saved weights into it.
network = create_model()
network.load_weights('weights.ckpt')
print('loaded weights!')
network.evaluate(ds_val)
datasets: (60000, 28, 28) (60000,) 0 255
(128, 784) (128, 10)
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                multiple                  200960
_________________________________________________________________
dense_1 (Dense)              multiple                  32896
_________________________________________________________________
dense_2 (Dense)              multiple                  8256
_________________________________________________________________
dense_3 (Dense)              multiple                  2080
_________________________________________________________________
dense_4 (Dense)              multiple                  330
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
Epoch 1/3
469/469 [==============================] - 5s 12ms/step - loss: 0.2876 - accuracy: 0.8335
Epoch 2/3
469/469 [==============================] - 5s 11ms/step - loss: 0.1430 - accuracy: 0.9551 - val_loss: 0.1397 - val_accuracy: 0.9634
Epoch 3/3
469/469 [==============================] - 4s 9ms/step - loss: 0.1155 - accuracy: 0.9681
79/79 [==============================] - 1s 8ms/step - loss: 0.1344 - accuracy: 0.9654
saved weights.
loaded weights!
79/79 [==============================] - 1s 13ms/step - loss: 0.1344 - accuracy: 0.9593





[0.13439734456132318, 0.9654]

Save/load entire model

完美保存所有信息
# Illustrative snippet: `network`, `x_val` and `y_val` come from the
# surrounding training script.

# Save the whole model to a single HDF5 file, then drop the in-memory object.
network.save('model.h5')
print('saved total model.')
del network

# Restore the model directly from the file; no architecture code is needed.
print('load model from file')
network = tf.keras.models.load_model('model.h5')

# NOTE(review): x_val / y_val presumably need the same preprocessing as the
# training data (scaled, flattened to 784, one-hot labels) -- confirm.
network.evaluate(x_val, y_val)
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics


def preprocess(x, y):
    """Convert one (image, label) pair into the tensors the network consumes.

    x is a single image, not a batch: it is cast to float32, divided by 255
    and flattened to a 784-element vector.  y is cast to int32 and one-hot
    encoded over 10 classes.
    """
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [28 * 28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


batchsz = 128
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())

# Training pipeline: per-sample preprocessing, shuffle, then batch.
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
# Validation pipeline: same preprocessing, no shuffling.
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)

sample = next(iter(db))
print(sample[0].shape, sample[1].shape)

network = Sequential([
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(32, activation='relu'),
    layers.Dense(10)
])
network.build(input_shape=(None, 28 * 28))
network.summary()

# `learning_rate` replaces the deprecated `lr` keyword.
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

# Validation runs every second epoch (validation_freq=2).
network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)

network.evaluate(ds_val)

# Save the entire model to one HDF5 file and discard the in-memory object.
network.save('model.h5')
print('saved total model.')
del network

print('load model from file')

# Restore the model from the file alone; no architecture code is repeated.
network1 = tf.keras.models.load_model('model.h5')
# Compile again: the loader warns that a Sequential model without an
# `input_shape` on its first layer cannot reload its optimizer state.
network1.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                 loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                 metrics=['accuracy'])
# Rebuild the validation set with batch-level preprocessing: the same steps
# as preprocess(), applied to the whole array at once (hence the -1).
x_val = tf.cast(x_val, dtype=tf.float32) / 255.
x_val = tf.reshape(x_val, [-1, 28 * 28])
y_val = tf.cast(y_val, dtype=tf.int32)
y_val = tf.one_hot(y_val, depth=10)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(128)
network1.evaluate(ds_val)
datasets: (60000, 28, 28) (60000,) 0 255
(128, 784) (128, 10)
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_20 (Dense)             multiple                  200960
_________________________________________________________________
dense_21 (Dense)             multiple                  32896
_________________________________________________________________
dense_22 (Dense)             multiple                  8256
_________________________________________________________________
dense_23 (Dense)             multiple                  2080
_________________________________________________________________
dense_24 (Dense)             multiple                  330
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
Epoch 1/3
469/469 [==============================] - 6s 13ms/step - loss: 0.2851 - accuracy: 0.8405
Epoch 2/3
469/469 [==============================] - 6s 13ms/step - loss: 0.1365 - accuracy: 0.9580 - val_loss: 0.1422 - val_accuracy: 0.9590
Epoch 3/3
469/469 [==============================] - 5s 11ms/step - loss: 0.1130 - accuracy: 0.9661
79/79 [==============================] - 1s 10ms/step - loss: 0.1201 - accuracy: 0.9714
saved total model.
load model from file


W0525 16:44:50.178785 4587234752 hdf5_format.py:266] Sequential models without an `input_shape` passed to the first layer cannot reload their optimizer state. As a result, your model isstarting with a freshly initialized optimizer.


79/79 [==============================] - 1s 7ms/step - loss: 0.1201 - accuracy: 0.9672





[0.12005392337660747, 0.9714]

saved_model

通用,包括 PyTorch、其他语言

用于工业环境的部署

# Illustrative snippet: `m` (the model to export) and `path` (the directory
# to load from) are placeholders that are not defined here.

# Export the model in the SavedModel format.
tf.saved_model.save(m, '/tmp/saved_model/')

# Load it back and fetch the default serving signature.
imported = tf.saved_model.load(path)
f = imported.signatures['serving_default']
# Call the signature on a dummy all-ones input of shape [1, 28, 28, 3].
print(f(x=tf.ones([1, 28, 28, 3])))

查看更多关于模型加载与保存的详细内容...

  阅读:34次