目录
Outline
Save/load weights
Save/load entire model
saved_model
Outline
save/load weights:只记录部分信息(模型权重)
save/load entire model:记录所有信息(网络结构、权重及训练配置)
saved_model:通用的部署格式,不依赖 Python 代码,可被其他语言的 TensorFlow 运行时加载
Save/load weights
# Save only part of the model's information: the weights (not the architecture).
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights.
# Because only weights were saved, the architecture must be recreated in code
# first; `create_model` is assumed to return the same (untrained) architecture
# as `model` -- it is defined elsewhere, not in this snippet.
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

# `test_images` / `test_labels` are assumed to be the held-out evaluation data.
loss, acc = model.evaluate(test_images, test_labels)
print(f'Restored model, accuracy: {100*acc:5.2f}%')
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics


def preprocess(x, y):
    """Turn one MNIST (image, label) pair into model-ready tensors.

    x is a single image, not a batch: it is scaled from [0, 255] to float32
    in [0, 1] and flattened to a 784-element vector. y is cast to int32 and
    one-hot encoded over the 10 digit classes.
    """
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [28 * 28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


def create_network():
    """Return a new, built and compiled 784-256-128-64-32-10 MLP.

    save_weights() stores only the parameters, so the identical architecture
    has to be recreated in code before load_weights(); defining it once here
    guarantees the saved and the restored model match.
    """
    net = Sequential([
        layers.Dense(256, activation='relu'),
        layers.Dense(128, activation='relu'),
        layers.Dense(64, activation='relu'),
        layers.Dense(32, activation='relu'),
        layers.Dense(10),  # no activation: outputs logits, hence from_logits=True
    ])
    # Build up front so the variables exist before load_weights() and the
    # checkpoint is restored into them immediately instead of being deferred.
    net.build(input_shape=(None, 28 * 28))
    # `learning_rate` replaces the deprecated `lr` keyword, which newer
    # Keras optimizers reject.
    net.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
    return net


batchsz = 128
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())

db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)

sample = next(iter(db))
print(sample[0].shape, sample[1].shape)

network = create_network()
network.summary()
network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)
network.evaluate(ds_val)

# NOTE: 'weights.ckpt' produces a TF checkpoint under tf.keras (Keras 2);
# Keras 3 instead requires a filename ending in '.weights.h5'.
network.save_weights('weights.ckpt')
print('saved weights.')
del network

# Only the weights were saved: recreate the architecture, then load into it.
network = create_network()
network.load_weights('weights.ckpt')
print('loaded weights!')
network.evaluate(ds_val)
datasets: (60000, 28, 28) (60000,) 0 255
(128, 784) (128, 10)
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                multiple                  200960
_________________________________________________________________
dense_1 (Dense)              multiple                  32896
_________________________________________________________________
dense_2 (Dense)              multiple                  8256
_________________________________________________________________
dense_3 (Dense)              multiple                  2080
_________________________________________________________________
dense_4 (Dense)              multiple                  330
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
Epoch 1/3
469/469 [==============================] - 5s 12ms/step - loss: 0.2876 - accuracy: 0.8335
Epoch 2/3
469/469 [==============================] - 5s 11ms/step - loss: 0.1430 - accuracy: 0.9551 - val_loss: 0.1397 - val_accuracy: 0.9634
Epoch 3/3
469/469 [==============================] - 4s 9ms/step - loss: 0.1155 - accuracy: 0.9681
79/79 [==============================] - 1s 8ms/step - loss: 0.1344 - accuracy: 0.9654
saved weights.
loaded weights!
79/79 [==============================] - 1s 13ms/step - loss: 0.1344 - accuracy: 0.9593
[0.13439734456132318, 0.9654]
Save/load entire model
# Save all of the model's information in one file: architecture, weights and
# training configuration.
network.save('model.h5')
print('saved total model.')
del network

print('load model from file')
network = tf.keras.models.load_model('model.h5')
# Evaluate on the preprocessed, batched validation dataset `ds_val` from the
# training script. The raw x_val / y_val arrays (uint8 28x28 images, integer
# labels) do not match the model's (None, 784) float input or its one-hot
# CategoricalCrossentropy targets.
network.evaluate(ds_val)
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics


def preprocess(x, y):
    """Turn one MNIST (image, label) pair into model-ready tensors.

    x is a single image, not a batch: it is scaled from [0, 255] to float32
    in [0, 1] and flattened to a 784-element vector. y is cast to int32 and
    one-hot encoded over the 10 digit classes.
    """
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [28 * 28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


def compile_network(net):
    """Compile `net` in place with the optimizer, loss and metric used here."""
    # `learning_rate` replaces the deprecated `lr` keyword, which newer
    # Keras optimizers reject. The last Dense layer has no activation, so
    # the loss receives logits (from_logits=True).
    net.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])


batchsz = 128
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())

db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)

sample = next(iter(db))
print(sample[0].shape, sample[1].shape)

# NOTE: the first layer has no `input_shape`, which is why load_model() below
# logs that it cannot reload the optimizer state and starts a fresh optimizer.
network = Sequential([
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(32, activation='relu'),
    layers.Dense(10),
])
network.build(input_shape=(None, 28 * 28))
network.summary()
compile_network(network)
network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)
network.evaluate(ds_val)

network.save('model.h5')
print('saved total model.')
del network

print('load model from file')
network1 = tf.keras.models.load_model('model.h5')
# Recompile explicitly so the loaded model has a known optimizer configuration.
compile_network(network1)
# ds_val already holds the validation set with exactly the preprocessing the
# model was trained with, so reuse it rather than re-deriving it by hand.
network1.evaluate(ds_val)
datasets: (60000, 28, 28) (60000,) 0 255
(128, 784) (128, 10)
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_20 (Dense)             multiple                  200960
_________________________________________________________________
dense_21 (Dense)             multiple                  32896
_________________________________________________________________
dense_22 (Dense)             multiple                  8256
_________________________________________________________________
dense_23 (Dense)             multiple                  2080
_________________________________________________________________
dense_24 (Dense)             multiple                  330
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
Epoch 1/3
469/469 [==============================] - 6s 13ms/step - loss: 0.2851 - accuracy: 0.8405
Epoch 2/3
469/469 [==============================] - 6s 13ms/step - loss: 0.1365 - accuracy: 0.9580 - val_loss: 0.1422 - val_accuracy: 0.9590
Epoch 3/3
469/469 [==============================] - 5s 11ms/step - loss: 0.1130 - accuracy: 0.9661
79/79 [==============================] - 1s 10ms/step - loss: 0.1201 - accuracy: 0.9714
saved total model.
load model from file
W0525 16:44:50.178785 4587234752 hdf5_format.py:266] Sequential models without an `input_shape` passed to the first layer cannot reload their optimizer state. As a result, your model is starting with a freshly initialized optimizer.
79/79 [==============================] - 1s 7ms/step - loss: 0.1201 - accuracy: 0.9672
[0.12005392337660747, 0.9714]
saved_model
通用格式,与语言无关:不依赖 Python 源码,可被其他语言(如 C++、Java)的 TensorFlow 运行时加载
主要用于生产(工业)环境中的模型部署
# Export `m` (a trackable object such as a Keras model, defined elsewhere) in
# the language-neutral SavedModel format, then load it back from the SAME
# directory -- the original snippet referenced an undefined `path` here.
path = '/tmp/saved_model/'
tf.saved_model.save(m, path)

imported = tf.saved_model.load(path)
# Default serving signature. Its input keyword (`x` here) and the expected
# input shape depend on how `m` was defined -- adjust to the exported model.
f = imported.signatures['serving_default']
print(f(x=tf.ones([1, 28, 28, 3])))
声明:本文来自网络,不代表【好得很程序员自学网】立场,转载请注明出处:http://haodehen.cn/did127437