Python TensorFlow: How to restart training with optimizer and import_meta_graph? -


i'm trying restart model training in tensorflow picking left off. i'd use added (0.12+ think) import_meta_graph() not reconstruct graph.

i've seen solutions this, e.g. tensorflow: how save/restore model?, run issues adamoptimizer, valueerror: cannot add op name <my weights variable name>/adam name used error. this can fixed initializing, model values cleared!

there other answers , full examples out there, seem older , don't include newer import_meta_graph() approach, or don't have non-tensor optimizer. closest question find tensorflow: saving , restoring session there no final clear cut solution , example pretty complicated.

ideally i'd simple run-able example starting scratch, stopping, picking again. have works (below), wonder if i'm missing something. surely i'm not 1 doing this?

here came reading docs, other similar solutions, , trial , error. it's simple autoencoder on random data. if ran, ran again, continue left off (i.e. cost function on first run goes ~0.5 -> 0.3 second run starts ~0.3). unless missed something, of saving, constructors, model building, add_to_collection there needed , in precise order, there may simpler way.

and yes, loading graph import_meta_graph isn't needed here since code right above, want in actual application.

from __future__ import print_function import tensorflow tf import os import math import numpy np  output_dir = "/root/data/temp" model_checkpoint_file_base = os.path.join(output_dir, "model.ckpt")  input_length = 10 encoded_length = 3 learning_rate = 0.001 n_epochs = 10 n_batches = 10 if not os.path.exists(model_checkpoint_file_base + ".meta"):     print("making new")     brand_new = true      x_in = tf.placeholder(tf.float32, [none, input_length], name="x_in")     w_enc = tf.variable(tf.random_uniform([input_length, encoded_length],                                           -1.0 / math.sqrt(input_length),                                           1.0 / math.sqrt(input_length)), name="w_enc")     b_enc = tf.variable(tf.zeros(encoded_length), name="b_enc")     encoded = tf.nn.tanh(tf.matmul(x_in, w_enc) + b_enc, name="encoded")     w_dec = tf.transpose(w_enc, name="w_dec")     b_dec = tf.variable(tf.zeros(input_length), name="b_dec")     decoded = tf.nn.tanh(tf.matmul(encoded, w_dec) + b_dec, name="decoded")     cost = tf.sqrt(tf.reduce_mean(tf.square(decoded - x_in)), name="cost")      saver = tf.train.saver() else:     print("reloading existing")     brand_new = false     saver = tf.train.import_meta_graph(model_checkpoint_file_base + ".meta")     g = tf.get_default_graph()     x_in = g.get_tensor_by_name("x_in:0")     cost = g.get_tensor_by_name("cost:0")   sess = tf.session() if brand_new:     optimizer = tf.train.adamoptimizer(learning_rate).minimize(cost)     init = tf.global_variables_initializer()     sess.run(init)     tf.add_to_collection("optimizer", optimizer) else:     saver.restore(sess, model_checkpoint_file_base)     optimizer = tf.get_collection("optimizer")[0]  epoch_i in range(n_epochs):     batch in range(n_batches):         batch = np.random.rand(50, input_length)         _, curr_cost = sess.run([optimizer, cost], feed_dict={x_in: batch})         print("batch_cost:", curr_cost)         save_path = tf.train.saver().save(sess, model_checkpoint_file_base) 

Comments