57 lines
2.1 KiB
Python
57 lines
2.1 KiB
Python
|
from sklearn.linear_model import LinearRegression
|
||
|
from sklearn.model_selection import train_test_split
|
||
|
import statsmodels.api as sm
|
||
|
|
||
|
|
||
|
class OlsModel:
    """Ordinary-least-squares regression fitted with statsmodels.

    The model is fitted once, at construction time, on ``x`` augmented with an
    intercept column (``sm.add_constant``). The fitted results object is kept
    in ``self.results``.

    NOTE: because the design matrix includes the added constant column,
    ``self.results.predict`` must be given exog that also contains that column
    (e.g. via ``sm.add_constant(new_x, has_constant="add")``).
    """

    def __init__(self, x, y):
        # Keep the raw (constant-free) inputs for later inspection.
        self.x, self.y = x, y
        self.results = self.create_model()

    def create_model(self):
        """Fit OLS of ``self.y`` on ``self.x`` plus an intercept; return the fit results."""
        design = sm.add_constant(self.x)
        return sm.OLS(self.y, design).fit()
|
||
|
|
||
|
|
||
|
class MlrModel:
    """Multiple linear regression fitted with scikit-learn on a train/test split.

    At construction time the data is split with ``train_test_split`` and a
    ``LinearRegression`` is fitted on the training part only. The fitted
    estimator is kept in ``self.results``. The split itself is stored on the
    instance (``x_train``, ``x_test``, ``y_train``, ``y_test``) so the held-out
    part can be used for evaluation, e.g.
    ``self.results.score(self.x_test, self.y_test)``.

    Args:
        x: feature matrix passed to ``train_test_split``.
        y: target values passed to ``train_test_split``.
        test_size: fraction (or count) of samples held out; defaults to the
            previously hard-coded 0.25.
        random_state: seed for the split; defaults to the previously
            hard-coded 42, so results are reproducible across runs.
    """

    def __init__(self, x, y, test_size=0.25, random_state=42):
        self.x = x
        self.y = y
        self.test_size = test_size
        self.random_state = random_state
        self.results = self.create_model()

    def create_model(self):
        """Split the data, fit ``LinearRegression`` on the training part, return it.

        Side effect: stores the four split arrays on ``self`` instead of
        discarding the held-out part.
        """
        (self.x_train, self.x_test,
         self.y_train, self.y_test) = train_test_split(
            self.x, self.y,
            test_size=self.test_size,
            random_state=self.random_state,
        )
        return LinearRegression().fit(self.x_train, self.y_train)
|
||
|
|
||
|
|
||
|
def ols_calcutate_all(x, qufu_mean_ols_model, qufu_std_ols_model,
                      kangla_mean_ols_model, kangla_std_ols_model,
                      yanshen_mean_ols_model, yanshen_std_ols_model):
    """Predict all six properties for ``x`` with OLS models, print and return them.

    Each ``*_ols_model`` argument must expose ``.results.predict`` (as
    ``OlsModel`` instances do). The predictions are printed one per line with
    Chinese labels, in this order:
    - yield mean (屈服均值)
    - tensile mean (抗拉均值)
    - elongation mean (延伸率均值)
    - yield std (屈服标准差)
    - tensile std (抗拉标准差)
    - elongation std (延伸率标准差)

    NOTE(review): ``OlsModel`` fits on ``sm.add_constant(x)``, so
    statsmodels' ``predict`` presumably needs ``x`` to already contain the
    constant column here — confirm with callers.

    Returns:
        dict with keys ``"qufu_mean"``, ``"kangla_mean"``, ``"yanshen_mean"``,
        ``"qufu_std"``, ``"kangla_std"``, ``"yanshen_std"`` mapping to the
        corresponding prediction.
    """
    rows = (
        ("屈服均值: ", "qufu_mean", qufu_mean_ols_model),
        ("抗拉均值: ", "kangla_mean", kangla_mean_ols_model),
        ("延伸率均值: ", "yanshen_mean", yanshen_mean_ols_model),
        ("屈服标准差: ", "qufu_std", qufu_std_ols_model),
        ("抗拉标准差: ", "kangla_std", kangla_std_ols_model),
        ("延伸率标准差: ", "yanshen_std", yanshen_std_ols_model),
    )
    # Predict in the same order the values are printed.
    predictions = {key: model.results.predict(x) for _, key, model in rows}
    # Every line ends with "\n" and print() appends one more, reproducing the
    # original output (including the trailing blank line).
    print("".join(label + str(predictions[key]) + "\n" for label, key, _ in rows))
    return predictions
|
||
|
|
||
|
|
||
|
def mlr_calcutate_all(x, qufu_mean_mlr_model, qufu_std_mlr_model,
                      kangla_mean_mlr_model, kangla_std_mlr_model,
                      yanshen_mean_mlr_model, yanshen_std_mlr_model):
    """Predict all six properties for ``x`` with MLR models, print and return them.

    Each ``*_mlr_model`` argument must expose ``.results.predict`` (as
    ``MlrModel`` instances do, where ``results`` is the fitted
    ``LinearRegression``). The predictions are printed one per line with
    Chinese labels, in this order:
    - yield mean (屈服均值)
    - tensile mean (抗拉均值)
    - elongation mean (延伸率均值)
    - yield std (屈服标准差)
    - tensile std (抗拉标准差)
    - elongation std (延伸率标准差)

    Returns:
        dict with keys ``"qufu_mean"``, ``"kangla_mean"``, ``"yanshen_mean"``,
        ``"qufu_std"``, ``"kangla_std"``, ``"yanshen_std"`` mapping to the
        corresponding prediction.
    """
    rows = (
        ("屈服均值: ", "qufu_mean", qufu_mean_mlr_model),
        ("抗拉均值: ", "kangla_mean", kangla_mean_mlr_model),
        ("延伸率均值: ", "yanshen_mean", yanshen_mean_mlr_model),
        ("屈服标准差: ", "qufu_std", qufu_std_mlr_model),
        ("抗拉标准差: ", "kangla_std", kangla_std_mlr_model),
        ("延伸率标准差: ", "yanshen_std", yanshen_std_mlr_model),
    )
    # Predict in the same order the values are printed.
    predictions = {key: model.results.predict(x) for _, key, model in rows}
    # Every line ends with "\n" and print() appends one more, reproducing the
    # original output (including the trailing blank line).
    print("".join(label + str(predictions[key]) + "\n" for label, key, _ in rows))
    return predictions
|
||
|
|
||
|
|
||
|
|
||
|
|