import solution
import random
import numpy as np
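# Hyperparameters for the first experiment: number of sampling iterations, number of
# topics, and the Dirichlet concentration parameters alpha and gamma (conventionally
# the document-topic and topic-word priors, respectively; the exact role depends on
# solution.lda). Seeds are fixed so the run is reproducible.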
iterations = 50
num_of_topics = 20
alpha = 0.1
gamma = 0.1
random.seed(42)
np.random.seed(42)
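# Load and preprocess the training split; this returns the tokenized documents,
# the vocabulary dictionary, and the newsgroup labels.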
train_docs, dictionary, train_newsgroups = solution.load_and_preprocess('train')
num_of_docs = len(train_docs)
num_of_words = len(dictionary)
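# Fit LDA on the training corpus. The returned tuple is unpacked below: topic
# assignments, the topic distributions saved during training, per-iteration
# entropies, and the document-topic / word-topic / topic count matrices
# (the count matrices suggest a collapsed Gibbs sampler, but that is up to solution.lda).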
lda_result = solution.lda(docs=train_docs, num_of_topics=num_of_topics, num_of_words=num_of_words,
                          iterations=iterations, alpha=alpha, gamma=gamma)
topics, saved_topics_distribution, entropies, doc_topics_count, word_topics_count, topics_count = lda_result
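# Diagnostic plots: the saved distribution over topics, entropy across iterations,
# and a histogram of the n=20 most frequent words for a few selected topics.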
solution.plot_distribution_over_topics(saved_topics_distribution, num_of_topics)
solution.plot_entropies(entropies)
solution.plot_word_histogram(selected_topics=[1, 5, 15], word_topics_count=word_topics_count, n=20, dictionary=dictionary)
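# Held-out evaluation: preprocess the test split with the training dictionary,
# fold the test documents into the trained model by reusing the word-topic and
# topic counts learned on the training set, and compare the LDA perplexity
# against a simple baseline perplexity.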
test_docs, _, test_newsgroups = solution.load_and_preprocess('test', dictionary=dictionary)
lda_new_data_results = solution.lda_new_data(docs=test_docs, num_of_topics=num_of_topics, num_of_words=num_of_words,
word_topics_count=word_topics_count, topics_count=topics_count,
iterations=iterations, alpha=alpha, gamma=gamma)
_, lda_perplexity, simple_perplexity = lda_new_data_results
print(f'{lda_perplexity=}, {simple_perplexity=}')
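# Second experiment: fewer iterations and topics, with smaller (sharper) Dirichlet
# priors. Seeds are reset so both runs start from the same random state.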
iterations = 20
num_of_topics = 10
alpha = 0.01
gamma = 0.01
random.seed(42)
np.random.seed(42)
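# Refit LDA on the same training corpus with the new configuration.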
lda_result_2 = solution.lda(docs=train_docs, num_of_topics=num_of_topics, num_of_words=num_of_words,
                            iterations=iterations, alpha=alpha, gamma=gamma)
topics_2, saved_topics_distribution_2, entropies_2, doc_topics_count_2, word_topics_count_2, topics_count_2 = lda_result_2
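# Same diagnostic plots as before, written to separate files so the first run's
# figures are not overwritten.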
solution.plot_distribution_over_topics(saved_topics_distribution_2, num_of_topics, name='topics_distribution_2.png')
solution.plot_entropies(entropies_2, name='entropies_2.png')
solution.plot_word_histogram(selected_topics=[1, 5, 9], word_topics_count=word_topics_count_2, n=20, dictionary=dictionary, name='word_histogram_2.png')
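# Evaluate the second model on the same test set and compare perplexities.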
lda_new_data_results_2 = solution.lda_new_data(docs=test_docs, num_of_topics=num_of_topics, num_of_words=num_of_words,
word_topics_count=word_topics_count_2, topics_count=topics_count_2,
iterations=iterations, alpha=alpha, gamma=gamma)
_, lda_perplexity_2, simple_perplexity_2 = lda_new_data_results_2
print(f'{lda_perplexity_2=}, {simple_perplexity_2=}')