CreateCustomSearch
Create a search with customized search settings, validate the resulting model against a test dataset, and retrieve the actual and predicted values for additional analysis.
This example uses experimental data from a double pendulum. The variables represent the x positions (x1 and x2), velocities (v1 and v2), and accelerations (a1 and a2) of the two arms at different points in time (t).
Create a connection to Eureqa:
from eureqa import Eureqa
from eureqa import error_metric
e = Eureqa(url="https://rds.nutonian.com", user_name="user@nutonian.com", password="password")
Create a data source:
data_source = e.create_data_source("Double Pendulum - Training Data", "double_pendulum_train.csv")
Initialize search settings with the “numeric” template. The target variable is “a2”, the acceleration of the second arm of the double pendulum. The target itself and the time column “t” are excluded from the input variables:
variables = set(data_source.get_variables())
target_variable = "a2"
settings = e.search_templates.numeric("Model a2", target_variable, variables - {target_variable, 't'})
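The third argument above is an ordinary Python set difference. A minimal sketch of that step on its own, using a hard-coded set of column names (assumed here from the dataset description, not read from Eureqa):

```python
# Hypothetical column names, assumed from the dataset description above.
variables = {"t", "x1", "x2", "v1", "v2", "a1", "a2"}
target_variable = "a2"

# Remove the target and the time column; what remains are the candidate inputs.
input_variables = variables - {target_variable, "t"}
print(sorted(input_variables))
```

Using a set keeps the exclusion independent of column order in the CSV file.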
The double pendulum is a physical system. Customize the search settings to disable building blocks that are irrelevant to it, such as if-then-else, and to enable relevant ones, such as the trigonometric operators:
settings.math_blocks.const.complexity = 1 # set the complexity of constants to 1
settings.math_blocks.if_op.disable() # disable if-then-else
settings.math_blocks.less.disable() # disable less-than
settings.math_blocks.sin.enable(3) # enable sine and set complexity to 3
settings.math_blocks.cos.enable(3) # enable cosine and set complexity to 3
settings.error_metric = error_metric.mean_square_error() # optimize for mean squared error
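For reference, mean squared error is the average of the squared differences between actual and predicted values. The sketch below is a plain-Python illustration of that definition; the helper name `mse` is hypothetical, and Eureqa's own implementation may differ in details such as handling of missing values.

```python
def mse(actual, predicted):
    """Average of the squared differences between actual and predicted values."""
    if len(actual) != len(predicted):
        raise ValueError("actual and predicted must have the same length")
    return sum((a - p) ** 2 for a, p in zip(actual, predicted)) / len(actual)

# Toy values chosen for illustration only.
print(mse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
```

Because the differences are squared, a few large misses raise the metric more than many small ones.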
Create a search and run it for 30 seconds:
search = data_source.create_search(settings)
search.submit(30)
search.wait_until_done()
Get the best model and print it:
solution = search.get_best_solution()
print("The best model found is:\n %s = %s" % (solution.target, solution.model))
The best model found is:
a2 = -0.0239994769615275 - a1*cos(x2 - x1) - v1^2*sin(x2 - x1) - 9.81639183106058*sin(x2)
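The printed model can also be evaluated outside Eureqa. The sketch below transcribes it into a Python function, reading `^` as exponentiation and assuming the positions are angles in radians (Python's `math.sin` and `math.cos` expect radians; the units of the dataset are not stated here). The function name `predict_a2` is hypothetical.

```python
import math

def predict_a2(x1, x2, v1, a1):
    """Evaluate the model printed above; `^` in the model becomes `**` in Python."""
    return (-0.0239994769615275
            - a1 * math.cos(x2 - x1)
            - v1 ** 2 * math.sin(x2 - x1)
            - 9.81639183106058 * math.sin(x2))

# Inputs are made-up values for illustration, not rows from the dataset.
print(predict_a2(x1=0.0, x2=0.0, v1=0.0, a1=0.0))
```

With all inputs at zero, only the constant term remains, which is a quick sanity check on the transcription.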
Get the model performance:
print("The %s value for this search is %.2f" % (solution.optimized_error_metric, solution.optimized_error_metric_value))
The Mean Squared Error value for this search is 0.49
Evaluate the model against a test dataset withheld from Eureqa:
test_data_source = e.create_data_source("Double Pendulum - Test Data", "double_pendulum_test.csv")
test_metrics = e.compute_error_metrics(test_data_source, target_variable, solution.model)
test_mse = test_metrics.mean_square_error
print("The %s value for the test data is %.2f" % (solution.optimized_error_metric, test_mse))
The Mean Squared Error value for the test data is 1.34
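One way to put the two reported values side by side is to convert them to root mean squared error, which is in the same units as a2. The sketch below only reuses the two numbers printed above and makes no Eureqa calls; the variable names are illustrative.

```python
import math

search_mse = 0.49  # value reported for the search above
test_mse = 1.34    # value reported for the withheld test data above

# RMSE is in the same units as a2, which makes the error easier to interpret.
search_rmse = math.sqrt(search_mse)
test_rmse = math.sqrt(test_mse)
print("search RMSE = %.2f, test RMSE = %.2f" % (search_rmse, test_rmse))
print("test/search MSE ratio = %.2f" % (test_mse / search_mse))
```

A test error higher than the search error is expected to some degree, since the test rows were never seen during the search.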
Get the actual and predicted values for the test set and load them into a pandas DataFrame for further analysis:
import pandas as pd
predicted_and_actual = e.evaluate_expression(test_data_source, ["t", "a2", solution.model])
df = pd.DataFrame(predicted_and_actual)
df
| | -0.0239994769615275 - a1*cos(x2 - x1) - v1^2*sin(x2 - x1) - 9.81639183106058*sin(x2) | a2 | t |
---|---|---|---|
0 | -44.613217 | -42.80 | 7.14 |
1 | -45.778557 | -45.10 | 7.14 |
2 | -47.013552 | -47.30 | 7.15 |
3 | -51.071663 | -49.50 | 7.15 |
4 | -52.363975 | -51.70 | 7.15 |
5 | -53.608793 | -53.80 | 7.15 |
6 | -54.806116 | -55.80 | 7.16 |
7 | -58.057057 | -57.80 | 7.16 |
8 | -59.106151 | -59.50 | 7.16 |
9 | -59.895165 | -60.90 | 7.16 |
10 | -60.382955 | -61.80 | 7.17 |
11 | -61.704556 | -62.20 | 7.17 |
12 | -62.667552 | -62.20 | 7.17 |
13 | -62.072153 | -61.80 | 7.17 |
14 | -61.063694 | -60.80 | 7.17 |
15 | -58.765531 | -59.20 | 7.18 |
16 | -56.596001 | -56.90 | 7.18 |
17 | -53.955954 | -53.90 | 7.18 |
18 | -50.749857 | -50.10 | 7.18 |
19 | -47.745222 | -46.50 | 7.18 |
20 | -42.061859 | -42.30 | 7.18 |
21 | -38.071695 | -37.70 | 7.19 |
22 | -33.788941 | -32.60 | 7.19 |
23 | -27.042389 | -27.40 | 7.19 |
24 | -22.409397 | -22.00 | 7.19 |
25 | -13.933737 | -16.30 | 7.19 |
26 | -8.833737 | -10.30 | 7.19 |
27 | -3.433737 | -3.92 | 7.19 |
28 | 2.096263 | 2.55 | 7.20 |
29 | 7.598263 | 9.02 | 7.20 |
... | ... | ... | ... |
781 | -7.208298 | -7.15 | 9.75 |
782 | -6.753226 | -7.02 | 9.76 |
783 | -6.712980 | -6.89 | 9.77 |
784 | -6.672694 | -6.76 | 9.77 |
785 | -6.632427 | -6.64 | 9.78 |
786 | -6.589876 | -6.53 | 9.79 |
787 | -6.200674 | -6.42 | 9.80 |
788 | -6.175753 | -6.31 | 9.81 |
789 | -6.148107 | -6.21 | 9.82 |
790 | -6.117373 | -6.11 | 9.82 |
791 | -5.761961 | -6.01 | 9.83 |
792 | -5.742415 | -5.90 | 9.84 |
793 | -5.718649 | -5.80 | 9.85 |
794 | -5.688919 | -5.69 | 9.86 |
795 | -5.652192 | -5.58 | 9.87 |
796 | -5.329606 | -5.47 | 9.88 |
797 | -5.297914 | -5.35 | 9.89 |
798 | -5.257292 | -5.22 | 9.89 |
799 | -5.208666 | -5.09 | 9.90 |
800 | -5.151411 | -4.94 | 9.91 |
801 | -4.837154 | -4.79 | 9.92 |
802 | -4.777049 | -4.63 | 9.93 |
803 | -4.705135 | -4.45 | 9.94 |
804 | -4.625141 | -4.26 | 9.95 |
805 | -3.821506 | -4.06 | 9.96 |
806 | -3.706753 | -3.84 | 9.97 |
807 | -3.382900 | -3.64 | 9.98 |
808 | -3.284431 | -3.42 | 9.98 |
809 | -3.155859 | -3.19 | 9.99 |
810 | -3.020318 | -2.95 | 10.00 |
811 rows × 3 columns
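As one example of further analysis, the predicted column can be renamed and compared with the actual values. The sketch below rebuilds a small DataFrame from the first five rows of the table above so that it runs without a Eureqa connection; the column names `predicted`, `actual`, and `residual` are chosen here for illustration.

```python
import pandas as pd

# The predicted column is keyed by the model expression, as in the table above.
model = ("-0.0239994769615275 - a1*cos(x2 - x1) - v1^2*sin(x2 - x1)"
         " - 9.81639183106058*sin(x2)")

# First five rows copied from the table above, standing in for the full result.
df = pd.DataFrame({
    model: [-44.613217, -45.778557, -47.013552, -51.071663, -52.363975],
    "a2": [-42.80, -45.10, -47.30, -49.50, -51.70],
    "t": [7.14, 7.14, 7.15, 7.15, 7.15],
})

# Give the long expression column a short name, then compute residuals.
df = df.rename(columns={model: "predicted", "a2": "actual"})
df["residual"] = df["actual"] - df["predicted"]

# Mean squared error over these five rows only (not the full test set).
sample_mse = (df["residual"] ** 2).mean()
print(df[["t", "actual", "predicted", "residual"]])
print("MSE over the sample rows: %.3f" % sample_mse)
```

With the real result, the same rename and residual steps apply to the DataFrame built from `predicted_and_actual`, using `solution.model` as the key of the predicted column.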