python 3.x - TensorFlow hidden state doesn't seem to change
I'm trying to follow along with an RNN tutorial on Medium, refactoring as I go along. When I run the code, it appears to work, but when I tried to print out the current_state variable to see what's happening inside the neural network, I got all 1s. Is that expected behavior? Is the state not being updated for some reason? As I understand it, current_state should contain the latest values in the hidden layer for the batches, so it shouldn't be all 1s. Any help would be highly appreciated.
```python
def __train_minibatch__(self, batch_num, sess, current_state):
    """
    Trains one minibatch.

    :type batch_num: int
    :param batch_num: the current batch number.

    :type sess: tensorflow Session
    :param sess: the session during which training occurs.

    :type current_state: numpy matrix (array of arrays)
    :param current_state: the current hidden state.

    :type return: (float, numpy matrix)
    :param return: (the calculated loss for this minibatch, the updated hidden state)
    """
    start_index = batch_num * self.settings.truncate
    end_index = start_index + self.settings.truncate

    batch_x = self.x_train_batches[:, start_index:end_index]
    batch_y = self.y_train_batches[:, start_index:end_index]
    total_loss, train_step, current_state, predictions_series = sess.run(
        [self.total_loss_fun, self.train_step_fun, self.current_state, self.predictions_series],
        feed_dict={
            self.batch_x_placeholder: batch_x,
            self.batch_y_placeholder: batch_y,
            self.hidden_state: current_state
        })
    return total_loss, current_state, predictions_series
# end of __train_minibatch__()

def __train_epoch__(self, epoch_num, sess, current_state, loss_list):
    """
    Trains one full epoch.

    :type epoch_num: int
    :param epoch_num: the number of the current epoch.

    :type sess: tensorflow Session
    :param sess: the session during which training occurs.

    :type current_state: numpy matrix
    :param current_state: the current hidden state.

    :type loss_list: list of floats
    :param loss_list: holds the losses incurred during training.

    :type return: (float, numpy matrix)
    :param return: (the latest incurred loss, the latest hidden state)
    """
    self.logger.info("Starting epoch: %d" % (epoch_num))
    for batch_num in range(self.num_batches):
        # Debug log is outside of the function to reduce the number of arguments.
        self.logger.debug("Training minibatch : %d | epoch : %d" % (batch_num, epoch_num + 1))
        total_loss, current_state, predictions_series = self.__train_minibatch__(
            batch_num, sess, current_state)
        loss_list.append(total_loss)
    # end of batch training
    self.logger.info("Finished epoch: %d | loss: %f" % (epoch_num, total_loss))
    return total_loss, current_state, predictions_series
# end of __train_epoch__()

def train(self):
    """
    Trains the given model on the given dataset, and saves the losses incurred
    at the end of each epoch to a plot image.
    """
    self.logger.info("Started training the model.")
    self.__unstack_variables__()
    self.__create_functions__()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        loss_list = []
        current_state = np.zeros((self.settings.batch_size, self.settings.hidden_size), dtype=float)
        for epoch_idx in range(1, self.settings.epochs + 1):
            total_loss, current_state, predictions_series = self.__train_epoch__(
                epoch_idx, sess, current_state, loss_list)
            print("Shape: ", current_state.shape, " | Current output: ", current_state)
        # end of epoch training
        self.logger.info("Finished training the model. Final loss: %f" % total_loss)
        self.__plot__(loss_list)
        self.generate_output()
# end of train()
```
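For reference, the state-carrying pattern the training loop above implements can be sketched in plain NumPy: the hidden state returned by one mini-batch is fed back in as the initial state of the next. All shapes, weights, and the single-feature input here are made-up illustrations, not the actual graph from the question.

```python
import numpy as np

# Made-up shapes for illustration (not taken from the question's settings).
rng = np.random.default_rng(0)
batch_size, hidden_size, truncate = 5, 4, 15
W_xh = rng.normal(scale=0.1, size=(1, hidden_size))       # input-to-hidden weights
W_hh = rng.normal(scale=0.1, size=(hidden_size, hidden_size))  # hidden-to-hidden weights

def run_minibatch(batch_x, state):
    # One truncated-BPTT window: step through `truncate` time steps,
    # updating the hidden state at each step.
    for t in range(batch_x.shape[1]):
        state = np.tanh(batch_x[:, t:t + 1] @ W_xh + state @ W_hh)
    return state

current_state = np.zeros((batch_size, hidden_size))
for batch_num in range(3):
    batch_x = rng.normal(size=(batch_size, truncate))
    # Feed the previous mini-batch's final state in as the initial state.
    current_state = run_minibatch(batch_x, current_state)
print(current_state.shape)  # (5, 4)
```

Because each step passes through tanh, a healthy state stays strictly inside (-1, 1) and changes from mini-batch to mini-batch; a state pinned at exactly 1 everywhere is a sign something is off.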
Update
After completing the second part of the tutorial and using the built-in RNN API, the problem was gone, which means there's either something wrong with the way I use the current_state variable, or changes to the TensorFlow API caused something wacky to happen (I'm pretty sure it's the former, though). I'm going to leave the question open in case someone has a definitive answer.
First, you should make sure that "it appears to work" is true by testing that the error is actually getting lower.
A hypothesis I have is that the last batch is corrupted with zeros at the end, because the length of the data (total_series_length / batch_size) is not a multiple of truncated_backprop_length. (I did not check whether it happens to be filled with zeros. The code in the tutorial is too old to run on the current tf version, and I don't have the code.) A final mini-batch with zeros at the end could lead the final current_state to converge to all ones. On other mini-batches, current_state would not be all ones.
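The divisibility issue behind this hypothesis can be checked in NumPy. The concrete numbers below (50000 / 5 / 15) are assumptions mirroring the tutorial's usual settings, not values taken from the asker's code; the zero-padding step is likewise hypothetical, showing what the final mini-batch would look like if the data were padded up to a full window.

```python
import numpy as np

# Assumed, tutorial-style settings (not from the asker's code):
total_series_length = 50000
batch_size = 5
truncate = 15  # truncated_backprop_length

cols = total_series_length // batch_size  # 10000 columns per batch row
remainder = cols % truncate
print(remainder)  # 10 -> 10000 is not a multiple of 15

# If the data were zero-padded up to the next multiple of `truncate`,
# the final mini-batch would end in padding zeros:
padded_cols = cols + (truncate - remainder)  # 10005
x = np.ones((batch_size, cols))
x_padded = np.pad(x, ((0, 0), (0, padded_cols - cols)), mode='constant')
last_batch = x_padded[:, -truncate:]
print(last_batch[0])  # last 5 entries of each row are padding zeros
```

So under these assumed settings, the last window of each row would carry 5 padded zeros out of 15 time steps, which is enough to distort whatever state the network reports for that final mini-batch.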
You could try printing current_state each time you run sess.run in __train_minibatch__. Or maybe print it every 1000 mini-batches.
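A minimal sketch of that instrumentation: print the state only when batch_num is a multiple of the interval. The state update below is a dummy stand-in for the sess.run call (an assumption for the sake of a self-contained example), but the modulo guard is the pattern to drop into __train_minibatch__'s caller.

```python
import numpy as np

print_every = 1000  # arbitrary print interval
current_state = np.zeros((5, 4))

for batch_num in range(3000):
    # Dummy stand-in for the sess.run state update in __train_minibatch__.
    current_state = np.tanh(current_state + 0.01)
    if batch_num % print_every == 0:
        # Summarize the state instead of dumping the whole matrix.
        print("batch %d | state mean: %.4f" % (batch_num, current_state.mean()))
```

Printing a summary statistic (mean, min/max) rather than the full matrix makes it easy to spot when the state saturates toward 1.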