r"""Source code for tests.test_rnn: unit tests for ``src.StockRNN.StockRNN``."""

import pytest
from src.StockRNN import StockRNN
import numpy as np
from datetime import datetime


@pytest.fixture()
def _stock_rnn():
    r"""Provide a freshly constructed ``StockRNN`` for the IBM ticker.

    The model is built with a training window from 2017-01-01 to
    2018-01-01. The two positional ``1`` arguments are forwarded unchanged to
    ``StockRNN``; their meaning is defined by its constructor signature
    (NOTE(review): not visible here, confirm against ``src.StockRNN``).

    Returns:
        A new ``StockRNN`` instance for each test that requests this fixture.
    """
    window_start = datetime(2017, 1, 1)
    window_end = datetime(2018, 1, 1)
    model = StockRNN(
        "IBM",
        1,
        1,
        train_start_date=window_start,
        train_end_date=window_end,
    )
    return model


def test_populate_daily_stock_data(_stock_rnn: StockRNN):
    r"""Check that populated daily data splits evenly into sequence segments.

    Calls ``populate_daily_stock_data()`` explicitly, then asserts two things
    about ``daily_stock_data``:

    * it is non-empty;
    * its length is an exact multiple of ``sequence_segment_length``.
    """
    _stock_rnn.populate_daily_stock_data()
    n_days = len(_stock_rnn.daily_stock_data)
    # An empty dataset would satisfy the modulo check trivially, so require data.
    assert n_days > 0
    assert n_days % _stock_rnn.sequence_segment_length == 0
def test_populate_test_train_creates_correct_number_of_randomly_ordered_segments(
    _stock_rnn: StockRNN,
):
    r"""Check the seeded train/test split of the daily data into segments.

    After populating the daily data, ``populate_test_train(rand_seed=1)`` is
    called. The test pins the exact segment indices assigned to each split:
    5 test and 20 train indices, which together cover 0..24. It also checks
    that ``test_set`` and ``train_set`` have matching lengths.
    """
    _stock_rnn.populate_daily_stock_data()
    _stock_rnn.populate_test_train(rand_seed=1)

    expected_test_indices = np.array([5, 8, 9, 11, 12], dtype=np.int64)
    expected_train_indices = np.array(
        [14, 13, 17, 3, 21, 10, 18, 19, 4, 2, 20, 6, 7, 22, 1, 16, 0, 15, 24, 23],
        dtype=np.int64,
    )

    assert np.array_equal(_stock_rnn.test_sample_indices, expected_test_indices)
    assert np.array_equal(_stock_rnn.train_sample_indices, expected_train_indices)
    # Each dataset should hold one sample per segment index in its split (20 / 5).
    assert len(_stock_rnn.train_set) == len(expected_train_indices)
    assert len(_stock_rnn.test_set) == len(expected_test_indices)
# TODO: add more tests!

if __name__ == "__main__":
    # Run only this module's tests (not everything under the current directory)
    # and hand pytest's exit status back to the shell.
    raise SystemExit(pytest.main([__file__]))