Sunday, June 27, 2021

TensorFlow Decision Forests

# Introducing TensorFlow Decision Forests

# Source: https://blog.tensorflow.org/2021/05/introducing-tensorflow-decision-forests.html


# Next step: the TF-DF tutorials at https://www.tensorflow.org/decision_forests/tutorials
################################################################################


# Notebook-only setup commands: uncomment them in Colab/Jupyter, or run them in
# a shell without the leading "#!".
#!pip install tensorflow_decision_forests
#!wget "https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv"

# Load TensorFlow Decision Forests
import tensorflow_decision_forests as tfdf

# Load the Palmer Penguins dataset (downloaded by the wget command above)
# using pandas
import pandas
df = pandas.read_csv("penguins.csv")

# Hold out 20% of the rows for testing.
# - A fixed random_state makes the split (and therefore every accuracy printed
#   later) reproducible from run to run.
# - Stratifying on the label keeps the species proportions the same in both
#   splits, which matters on a dataset this small.
from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    stratify=df["species"],
)

# Convert the pandas dataframe into a TensorFlow dataset whose label column is
# "species"
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="species")

# Train a Random Forest with default hyper-parameters
model = tfdf.keras.RandomForestModel()
model.fit(train_ds)


# Convert the held-out test split to a TensorFlow dataset. The split comes from
# train_test_split on penguins.csv rather than from a separate test CSV.
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="species")

# Evaluate the model. compile() here only specifies which metric to report.
model.compile(metrics=["accuracy"])
# return_dict=True prints labeled values, e.g. {"loss": ..., "accuracy": ...},
# instead of a bare [loss, accuracy] list.
print(model.evaluate(test_ds, return_dict=True))
# The original blog post reported an accuracy of about 0.979; the exact value
# depends on which rows end up in test_df.
# Note: Cross-validation would be more suitable for this small dataset.
# See also the out-of-bag evaluation described in the original TF-DF blog post.

# Export the model to a TensorFlow SavedModel
model.save("project/my_first_model")

# Plot the first tree of the forest.
# plot_model_in_colab() only renders inside a Colab/Jupyter notebook; in a
# plain script its result is discarded. plot_model() returns the plot as an
# HTML string, so write it to a file that can be opened in a browser.
# (In a notebook, tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0)
# displays the same plot inline.)
with open("tree_0.html", "w", encoding="utf-8") as plot_file:
    plot_file.write(tfdf.model_plotter.plot_model(model, tree_idx=0))

# Print the structure and training statistics of the model
model.summary()

# Print the variable importances. A bare expression only displays in a
# notebook, so print explicitly. Which importance keys are present depends on
# the learner and its settings.
# NOTE(review): "MEAN_DECREASE_IN_ACCURACY" presumably requires out-of-bag
# variable importances to be enabled when training the Random Forest;
# TODO confirm against the RandomForestModel documentation before indexing it.
print(model.make_inspector().variable_importances())

# List all the other available learning algorithms
print(tfdf.keras.get_all_models())

# Display the hyper-parameters of the Gradient Boosted Trees model.
# In IPython/Colab, "?tfdf.keras.GradientBoostedTreesModel" opens a help pane
# describing the model; in a plain Python session, use help() instead:
# help(tfdf.keras.GradientBoostedTreesModel)

# Create another model with specified hyper-parameters
model = tfdf.keras.GradientBoostedTreesModel(
    num_trees=500,
    growing_strategy="BEST_FIRST_GLOBAL",
    max_depth=8,
    split_axis="SPARSE_OBLIQUE",
)

# Train the new model on the same training split. It is freshly constructed,
# so evaluating it without fit() would not measure a trained model.
model.fit(train_ds)

# Evaluate the model on the same held-out test split as the Random Forest
model.compile(metrics=["accuracy"])
print(model.evaluate(test_ds, return_dict=True))

No comments:

Post a Comment

How to check local and global Angular versions

 Use the command `ng version` (or `ng v`) to find the version of the Angular CLI in the current folder. Run it outside of the Angular project to f...