请问代码训练好了如何具体实战?

来源:4-36 动手实现RNN-LSTM循环神经网络(十一):实际训练和测试

永不止息L

2019-01-01

老师,比如我现在有如下一段用来做股票价格预测的代码:

import pandas
import numpy
from keras.layers.core import Dense, Activation, Dropout
from keras.layers.recurrent import LSTM
from keras.models import Sequential
import matplotlib.pyplot as plt

CONST_TRAINING_SEQUENCE_LENGTH = 12
CONST_TESTING_CASES = 5


def dataNormalization(data):
    return [(datum - data[0]) / data[0] for datum in data]


def dataDeNormalization(data, base):
    return [(datum + 1) * base for datum in data]


def getDeepLearningData(ticker):
    # Step 1. Load data
    data = pandas.read_csv('./data/Intraday/' + ticker + '.csv')[
        'close'].tolist()
    # Step 2. Building Training data
    dataTraining = []
    for i in range(len(data) - CONST_TESTING_CASES * CONST_TRAINING_SEQUENCE_LENGTH):
        dataSegment = data[i:i + CONST_TRAINING_SEQUENCE_LENGTH + 1]
        dataTraining.append(dataNormalization(dataSegment))

    dataTraining = numpy.array(dataTraining)
    numpy.random.shuffle(dataTraining)
    X_Training = dataTraining[:, :-1]
    Y_Training = dataTraining[:, -1]

    # Step 3. Building Testing data
    X_Testing = []
    Y_Testing_Base = []
    for i in range(CONST_TESTING_CASES, 0, -1):
        dataSegment = data[-(i + 1) * CONST_TRAINING_SEQUENCE_LENGTH:-i * CONST_TRAINING_SEQUENCE_LENGTH]
        Y_Testing_Base.append(dataSegment[0])
        X_Testing.append(dataNormalization(dataSegment))

    Y_Testing = data[-CONST_TESTING_CASES * CONST_TRAINING_SEQUENCE_LENGTH:]

    X_Testing = numpy.array(X_Testing)
    Y_Testing = numpy.array(Y_Testing)

    # Step 4. Reshape for deep learning
    X_Training = numpy.reshape(X_Training, (X_Training.shape[0], X_Training.shape[1], 1))
    X_Testing = numpy.reshape(X_Testing, (X_Testing.shape[0], X_Testing.shape[1], 1))

    return X_Training, Y_Training, X_Testing, Y_Testing, Y_Testing_Base


def predict(model, X):
    predictionsNormalized = []

    for i in range(len(X)):
        data = X[i]
        result = []

        for j in range(CONST_TRAINING_SEQUENCE_LENGTH):
            predicted = model.predict(data[numpy.newaxis, :, :])[0, 0]
            result.append(predicted)
            data = data[1:]
            data = numpy.insert(data, [CONST_TRAINING_SEQUENCE_LENGTH - 1], predicted, axis=0)

        predictionsNormalized.append(result)

    return predictionsNormalized


def plotResults(Y_Hat, Y):
    plt.plot(Y)

    for i in range(len(Y_Hat)):
        padding = [None for _ in range(i * CONST_TRAINING_SEQUENCE_LENGTH)]
        plt.plot(padding + Y_Hat[i])

    plt.show()


def predictLSTM(ticker):
    # Step 1. Load data
    X_Training, Y_Training, X_Testing, Y_Testing, Y_Testing_Base = getDeepLearningData(ticker)

    # Step 2. Build model
    model = Sequential()

    model.add(LSTM(
        input_shape=(None, 1),
        units=50,
        return_sequences=True))
    model.add(Dropout(0.2))

    model.add(LSTM(
        200,
        return_sequences=False))
    model.add(Dropout(0.2))

    model.add(Dense(units=1))
    model.add(Activation('linear'))

    model.compile(loss='mse', optimizer='rmsprop')

    # Step 3. Train model
    model.fit(X_Training, Y_Training,
              batch_size=512,
              epochs=27,
              validation_split=0.05)

    # Step 4. Predict
    predictionsNormalized = predict(model, X_Testing)

    # Step 5. De-nomalize
    predictions = []
    for i, row in enumerate(predictionsNormalized):
        predictions.append(dataDeNormalization(row, Y_Testing_Base[i]))

    # Step 6. Plot
    plotResults(predictions, Y_Testing)


predictLSTM(ticker='IBM')

这段代码只是在通过旧的数据来训练模型,如果我要把这个应用在实战,应该修改哪几个部分呢?因为我所有接触的教程都只是用旧的数据来测试,没有看到一个直接预测未来数据的,所以这一块不清楚。十分感谢!

写回答

1回答

Oscar

2019-01-13

RNN神经网络做预测需要基于旧数据来训练出模型(找到一定规律)才预测未来的。你如果直接预测那为什么还训练模型呢?

0
1
永不止息L
我这个训练模型已经用了旧的数据了,# Step 1. Load data data = pandas.read_csv('./data/Intraday/' + ticker + '.csv')[ 'close'].tolist() # Step 2. Building Training data 这里不是已经在用数据了吗?我的问题是应该怎么用已经训练好的模型去实际预测。目前我知道的大概是要先保存再调用是这样的吗?老师,你懂这段代码吗?
2019-01-13
共1条回复

基于Python玩转人工智能最火框架 TensorFlow应用实践

机器学习入门,打牢TensorFlow框架应用是关键!

2214 学习 · 688 问题

查看课程