Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
language: rust
sudo: false
dist: trusty # still in beta, but required for the prebuilt TF binaries

cache:
cargo: true
directories:
- $HOME/.cache/bazel

rust: nightly
rust: stable

install:
- export CC="gcc-4.9" CXX="g++-4.9"
- source travis-ci/install.sh

script:
- export RUST_BACKTRACE=1
- cargo test -vv -j 2 --features tensorflow_unstable
- cargo run --example regression
- cargo run --features tensorflow_unstable --example expressions
- cargo doc -vv --features tensorflow_unstable
- (cd tensorflow-sys && cargo test -vv -j 1)
- # TODO(#66): Re-enable: (cd tensorflow-sys && cargo test -vv -j 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that required in the CI? I couldn't find anybody that could reproduce #66...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes; this test fails on my local machine and on Travis: https://travis-ci.org/tensorflow/rust/builds/208479366

- (cd tensorflow-sys && cargo doc -vv)

addons:
Expand Down
6 changes: 6 additions & 0 deletions tensorflow-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,11 @@ links = "tensorflow"
libc = "0.2"

[build-dependencies]
curl = "0.4"
flate2 = "0.2"
pkg-config = "0.3"
semver = "0.5"
tar = "0.4"

[features]
tensorflow_gpu = []
94 changes: 93 additions & 1 deletion tensorflow-sys/build.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
extern crate curl;
extern crate flate2;
extern crate pkg_config;
extern crate semver;
extern crate tar;

use std::error::Error;
use std::fs::File;
use std::io::BufWriter;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::process;
use std::process::Command;
use std::{env, fs};

use curl::easy::Easy;
use flate2::read::GzDecoder;
use semver::Version;
use tar::Archive;

const LIBRARY: &'static str = "tensorflow";
const REPOSITORY: &'static str = "https://github.com/tensorflow/tensorflow.git";
const TARGET: &'static str = "tensorflow:libtensorflow.so";
// `VERSION` and `TAG` are separate because the tag is not always `'v' + VERSION`.
const VERSION: &'static str = "1.0.0";
const TAG: &'static str = "v1.0.0";
const MIN_BAZEL: &'static str = "0.3.2";

Expand All @@ -30,6 +41,84 @@ fn main() {
return;
}

if env::consts::ARCH == "x86_64" && (env::consts::OS == "linux" || env::consts::OS == "macos") {
install_prebuilt();
} else {
build_from_src();
}
}

fn remove_suffix(value: &mut String, suffix: &str) {
if value.ends_with(suffix) {
let n = value.len();
value.truncate(n - suffix.len());
}
}

fn extract<P: AsRef<Path>, P2: AsRef<Path>>(archive_path: P, extract_to: P2) {
let file = File::open(archive_path).unwrap();
let unzipped = GzDecoder::new(file).unwrap();
let mut a = Archive::new(unzipped);
a.unpack(extract_to).unwrap();
}

// Downloads and unpacks a prebuilt binary. Only works for certain platforms.
fn install_prebuilt() {
// Figure out the file names.
let os = match env::consts::OS {
"macos" => "darwin",
x => x,
};
let proc_type = if cfg!(feature = "tensorflow_gpu") {"gpu"} else {"cpu"};
let binary_url = format!(
"https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-{}-{}-{}-{}.tar.gz",
proc_type, os, env::consts::ARCH, VERSION);
log_var!(binary_url);
let short_file_name = binary_url.split("/").last().unwrap();
let mut base_name = short_file_name.to_string();
remove_suffix(&mut base_name, ".tar.gz");
log_var!(base_name);
let target_dir = PathBuf::from(&get!("CARGO_MANIFEST_DIR")).join("target");
if !target_dir.exists() {
fs::create_dir(&target_dir).unwrap();
}
let file_name = target_dir.join(short_file_name);
log_var!(file_name);

// Download the tarball.
if !file_name.exists() {
let f = File::create(&file_name).unwrap();
let mut writer = BufWriter::new(f);
let mut easy = Easy::new();
easy.url(&binary_url).unwrap();
easy.write_function(move |data| {
Ok(writer.write(data).unwrap())
}).unwrap();
easy.perform().unwrap();

let response_code = easy.response_code().unwrap();
if response_code != 200 {
panic!("Unexpected response code {} for {}", response_code, binary_url);
}
}

// Extract the tarball.
let unpacked_dir = target_dir.join(base_name);
let lib_dir = unpacked_dir.join("lib");
if !lib_dir.join(format!("lib{}.so", LIBRARY)).exists() {
extract(file_name, &unpacked_dir);
}

//run("find", |command| command); // TODO: remove
run("ls", |command| {
command.arg("-l").arg(lib_dir.to_str().unwrap())
}); // TODO: remove

println!("cargo:rustc-link-lib=dylib={}", LIBRARY);
println!("cargo:rustc-link-search={}", lib_dir.display());
}

fn build_from_src() {
let output = PathBuf::from(&get!("OUT_DIR"));
log_var!(output);
let source = PathBuf::from(&get!("CARGO_MANIFEST_DIR")).join(format!("target/source-{}", TAG));
Expand Down Expand Up @@ -71,7 +160,10 @@ fn main() {
let configure_hint_file = Path::new(&configure_hint_file_pb);
if !configure_hint_file.exists() {
run("bash",
|command| command.current_dir(&source).arg("-c").arg("yes ''|./configure"));
|command| command.current_dir(&source)
.env("TF_NEED_CUDA", if cfg!(feature = "tensorflow_gpu") {"1"} else {"0"})
.arg("-c")
.arg("yes ''|./configure"));
File::create(configure_hint_file).unwrap();
}
run("bazel", |command| {
Expand Down