python - TensorFlow - Prediction output dependent on batch size
I implemented a generative adversarial network in TensorFlow. At test time, the generated images are quite good if I generate them using the same batch_size I used during training (64); if I generate one image at a time, the result is horrible.
There are two possible causes:
- Batch normalization?
- Wrong usage of tf.shape for a dynamic batch size
Here is the code:
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import batch_norm


def conc(x, y):
    """Concatenate conditioning vector on feature map axis."""
    x_shapes = x.get_shape()
    y_shapes = y.get_shape()
    x0 = tf.shape(x)[0]  # dynamic batch size
    x1 = x_shapes[1].value
    x2 = x_shapes[2].value
    y3 = y_shapes[3].value
    return tf.concat([x, y * tf.ones(shape=(x0, x1, x2, y3))], 3)


def batch_normal(input, scope="scope", reuse=False):
    return batch_norm(input, epsilon=1e-5, decay=0.9, scale=True, scope=scope,
                      reuse=reuse, updates_collections=None)


def generator(z_var, y):
    y_dim = y.get_shape()[1].value
    z_var = tf.concat([z_var, y], 1)
    d1 = tf.layers.dense(z_var, 1024,
                         kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                         name='gen_fc1')
    d1 = tf.nn.relu(batch_normal(d1, scope='gen_bn1'))
    # add the second layer
    d1 = tf.concat([d1, y], 1)
    d2 = tf.layers.dense(d1, 7 * 7 * 128,
                         kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                         name='gen_fc2')
    d2 = tf.nn.relu(batch_normal(d2, scope='gen_bn2'))
    d2 = tf.reshape(d2, [-1, 7, 7, 128])
    y = tf.reshape(y, shape=[-1, 1, 1, y_dim])
    d2 = conc(d2, y)
    deconv1 = tf.layers.conv2d_transpose(d2, 64, (4, 4), strides=(2, 2), padding='same',
                                         kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                                         name='gen_deconv1')
    d3 = tf.nn.relu(batch_normal(deconv1, scope='gen_bn3'))
    d3 = conc(d3, y)
    deconv2 = tf.layers.conv2d_transpose(d3, 1, (4, 4), strides=(2, 2), padding='same',
                                         kernel_initializer=tf.random_normal_initializer(stddev=0.02),
                                         name='gen_deconv2')
    return tf.nn.sigmoid(deconv2)
You may have other bugs, but batch normalization is a big issue here.
Batch normalization computes the mean and variance of the variables at each layer in order to do the normalization. These are meant as a proxy for the real mean and variance of the variables, meaning the mean and variance estimated on the complete population instead of on a subset (the mini-batch). If the mini-batch is large enough, the approximated mean and variance are close enough to the real ones, but if you have a single example in the mini-batch, the estimation of the mean and variance is catastrophic.
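To make the single-example case concrete, here is a minimal sketch in plain Python (not TensorFlow). It normalizes one scalar feature with the batch's own statistics, without the learned scale and offset. The function name `batch_norm_train` and the sample distribution are assumptions made up for illustration.

```python
import math
import random


def batch_norm_train(batch, epsilon=1e-5):
    """Normalize scalar activations using the batch's own mean and variance."""
    mean = sum(batch) / len(batch)
    var = sum((x - mean) ** 2 for x in batch) / len(batch)
    return [(x - mean) / math.sqrt(var + epsilon) for x in batch]


random.seed(0)

# Hypothetical activations of one feature, drawn with mean 3 and std 2.
big_batch = [random.gauss(3.0, 2.0) for _ in range(64)]
normalized_big = batch_norm_train(big_batch)

# With a single example the mean equals the example and the variance is 0,
# so the normalized value is 0 whatever the input was.
single_a = batch_norm_train([10.0])
single_b = batch_norm_train([-4.0])
```

Both `single_a` and `single_b` come out as `[0.0]`: with a batch of one, the layer throws away the input entirely, which would explain images that look fine at batch size 64 and horrible at batch size 1.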
The usual fix is that, after training is done, you compute the mean and variance of the model variables on a large subset of the inputs (larger than the mini-batch). Then you (somehow) plug these values into the batch normalization layers and turn off the computation of the mean and variance from the mini-batch. This is non-trivial to do by hand, but I assume whatever library you are using can deal with it. If a library can't deal with this, it is pretty useless, since the trained model could never be used (unless you evaluate it on mini-batches, as you did).
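As a rough sketch of how a library can handle this, the toy class below (plain Python, one scalar feature, no learned scale or offset) keeps exponential moving averages of the batch statistics during training and uses them instead of the batch statistics at inference. The class name `BatchNorm1D` and its `is_training` flag are assumptions for illustration. The `batch_norm` function from tf.contrib used in the question has an `is_training` argument that plays a similar role, and as far as I can tell the code above never sets it to False at test time.

```python
import math
import random


class BatchNorm1D:
    """Toy batch norm for one scalar feature, with moving statistics."""

    def __init__(self, decay=0.9, epsilon=1e-5):
        self.decay = decay
        self.epsilon = epsilon
        self.moving_mean = 0.0
        self.moving_var = 1.0

    def __call__(self, batch, is_training):
        if is_training:
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # Track population estimates as exponential moving averages.
            self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
            self.moving_var = self.decay * self.moving_var + (1 - self.decay) * var
        else:
            # Inference: ignore the batch statistics, use the stored ones.
            mean, var = self.moving_mean, self.moving_var
        return [(x - mean) / math.sqrt(var + self.epsilon) for x in batch]


random.seed(0)
bn = BatchNorm1D()

# "Training": feed many mini-batches of 64 so the moving averages settle.
for _ in range(200):
    bn([random.gauss(3.0, 2.0) for _ in range(64)], is_training=True)

# Inference: the output no longer depends on how the inputs are batched.
sample = [random.gauss(3.0, 2.0) for _ in range(64)]
as_batch = bn(sample, is_training=False)
one_by_one = [bn([x], is_training=False)[0] for x in sample]
```

Here `as_batch` and `one_by_one` are identical, because in inference mode each example is normalized with the same stored mean and variance regardless of what else is in the batch.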
I found a tutorial online after a quick search. It might be deprecated, and there might be better ones.